summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon HO2023-11-10 18:21:06 +0100
committerGitHub2023-11-10 18:21:06 +0100
commit587f1ebc0178acb19029d3fc9a729c197082aba7 (patch)
treef29805e5426f9f3fabe12d3fdadda96a1e987880 /compiler
parent7fc7c82aa61d782b335e7cf37231fd9998cd0d89 (diff)
parentd300be95c28ff3147bb6f6a65992df5b9b571bdf (diff)
Merge pull request #44 from AeneasVerif/son_traits_types
Add support for traits
Diffstat (limited to '')
-rw-r--r--compiler/AssociatedTypes.ml681
-rw-r--r--compiler/Assumed.ml381
-rw-r--r--compiler/Config.ml31
-rw-r--r--compiler/Contexts.ml127
-rw-r--r--compiler/Driver.ml56
-rw-r--r--compiler/Extract.ml3424
-rw-r--r--compiler/ExtractBase.ml828
-rw-r--r--compiler/ExtractBuiltin.ml648
-rw-r--r--compiler/ExtractTypes.ml2477
-rw-r--r--compiler/FunsAnalysis.ml57
-rw-r--r--compiler/Interpreter.ml259
-rw-r--r--compiler/InterpreterBorrows.ml3
-rw-r--r--compiler/InterpreterBorrowsCore.ml16
-rw-r--r--compiler/InterpreterExpansion.ml59
-rw-r--r--compiler/InterpreterExpressions.ml195
-rw-r--r--compiler/InterpreterLoopsJoinCtxs.ml18
-rw-r--r--compiler/InterpreterLoopsMatchCtxs.ml19
-rw-r--r--compiler/InterpreterPaths.ml66
-rw-r--r--compiler/InterpreterPaths.mli11
-rw-r--r--compiler/InterpreterProjectors.ml15
-rw-r--r--compiler/InterpreterStatements.ml770
-rw-r--r--compiler/InterpreterStatements.mli9
-rw-r--r--compiler/InterpreterUtils.ml124
-rw-r--r--compiler/Invariants.ml82
-rw-r--r--compiler/LlbcAst.ml1
-rw-r--r--compiler/LlbcAstUtils.ml40
-rw-r--r--compiler/Logging.ml8
-rw-r--r--compiler/PrePasses.ml6
-rw-r--r--compiler/Print.ml184
-rw-r--r--compiler/PrintPure.ml272
-rw-r--r--compiler/Pure.ml194
-rw-r--r--compiler/PureMicroPasses.ml211
-rw-r--r--compiler/PureTypeCheck.ml62
-rw-r--r--compiler/PureUtils.ml190
-rw-r--r--compiler/ReorderDecls.ml8
-rw-r--r--compiler/Substitute.ml493
-rw-r--r--compiler/SymbolicAst.ml33
-rw-r--r--compiler/SymbolicToPure.ml830
-rw-r--r--compiler/SynthesizeSymbolic.ml38
-rw-r--r--compiler/Translate.ml672
-rw-r--r--compiler/TranslateCore.ml79
-rw-r--r--compiler/TypesAnalysis.ml36
-rw-r--r--compiler/Values.ml4
-rw-r--r--compiler/dune5
44 files changed, 8964 insertions, 4758 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml
new file mode 100644
index 00000000..581e218c
--- /dev/null
+++ b/compiler/AssociatedTypes.ml
@@ -0,0 +1,681 @@
+(** This file implements utilities to handle trait associated types, in
+ particular with normalization helpers.
+
+ When normalizing a type, we simplify the references to the trait associated
+ types, and choose a representative when there are equalities between types
+ enforced by local clauses (i.e., clauses of the shape [where Trait1::T = Trait2::U]).
+ *)
+
+module T = Types
+module TU = TypesUtils
+module V = Values
+module E = Expressions
+module A = LlbcAst
+module C = Contexts
+module Subst = Substitute
+module L = Logging
+module UF = UnionFind
+module PA = Print.EvalCtxLlbcAst
+
+(** The local logger *)
+let log = L.associated_types_log
+
+let trait_type_ref_substitute (subst : ('r, 'r1) Subst.subst)
+ (r : 'r C.trait_type_ref) : 'r1 C.trait_type_ref =
+ let { C.trait_ref; type_name } = r in
+ let trait_ref = Subst.trait_ref_substitute subst trait_ref in
+ { C.trait_ref; type_name }
+
+(* TODO: how not to duplicate below? *)
+module RTyOrd = struct
+ type t = T.rty
+
+ let compare = T.compare_rty
+ let to_string = T.show_rty
+ let pp_t = T.pp_rty
+ let show_t = T.show_rty
+end
+
+module STyOrd = struct
+ type t = T.sty
+
+ let compare = T.compare_sty
+ let to_string = T.show_sty
+ let pp_t = T.pp_sty
+ let show_t = T.show_sty
+end
+
+module RTyMap = Collections.MakeMap (RTyOrd)
+module STyMap = Collections.MakeMap (STyOrd)
+
+(* TODO: is it possible not to have this? *)
+module type TypeWrapper = sig
+ type t
+end
+
+(* TODO: don't manage to get the syntax right so using a functor *)
+module MakeNormalizer
+ (R : TypeWrapper)
+ (RTyMap : Collections.Map with type key = R.t T.region T.ty)
+ (M : Collections.Map with type key = R.t T.region C.trait_type_ref) =
+struct
+ let compute_norm_trait_types_from_preds
+ (trait_type_constraints : R.t T.region T.trait_type_constraint list) :
+ R.t T.region T.ty M.t =
+ (* Compute a union-find structure by recursively exploring the predicates and clauses *)
+ let norm : R.t T.region T.ty UF.elem RTyMap.t ref = ref RTyMap.empty in
+ let get_ref (ty : R.t T.region T.ty) : R.t T.region T.ty UF.elem =
+ match RTyMap.find_opt ty !norm with
+ | Some r -> r
+ | None ->
+ let r = UF.make ty in
+ norm := RTyMap.add ty r !norm;
+ r
+ in
+ let add_trait_type_constraint (c : R.t T.region T.trait_type_constraint) =
+ let trait_ty = T.TraitType (c.trait_ref, c.generics, c.type_name) in
+ let trait_ty_ref = get_ref trait_ty in
+ let ty_ref = get_ref c.ty in
+ let new_repr = UF.get ty_ref in
+ let merged = UF.union trait_ty_ref ty_ref in
+ (* Not sure the set operation is necessary, but I want to control which
+ representative is chosen *)
+ UF.set merged new_repr
+ in
+ (* Explore the local predicates *)
+ List.iter add_trait_type_constraint trait_type_constraints;
+ (* TODO: explore the local clauses *)
+ (* Compute the norm maps *)
+ let rbindings =
+ List.map (fun (k, v) -> (k, UF.get v)) (RTyMap.bindings !norm)
+ in
+ (* Filter the keys to keep only the trait type aliases *)
+ let rbindings =
+ List.filter_map
+ (fun (k, v) ->
+ match k with
+ | T.TraitType (trait_ref, generics, type_name) ->
+ assert (generics = TypesUtils.mk_empty_generic_args);
+ Some ({ C.trait_ref; type_name }, v)
+ | _ -> None)
+ rbindings
+ in
+ M.of_list rbindings
+end
+
+(** Compute the representative classes of trait associated types, for normalization *)
+let compute_norm_trait_stypes_from_preds
+ (trait_type_constraints : T.strait_type_constraint list) :
+ T.sty C.STraitTypeRefMap.t =
+ (* Compute the normalization map for the types with regions *)
+ let module R = struct
+ type t = T.region_var_id
+ end in
+ let module M = C.STraitTypeRefMap in
+ let module Norm = MakeNormalizer (R) (STyMap) (M) in
+ Norm.compute_norm_trait_types_from_preds trait_type_constraints
+
+(** Compute the representative classes of trait associated types, for normalization *)
+let compute_norm_trait_types_from_preds
+ (trait_type_constraints : T.rtrait_type_constraint list) :
+ T.ety C.ETraitTypeRefMap.t * T.rty C.RTraitTypeRefMap.t =
+ (* Compute the normalization map for the types with regions *)
+ let module R = struct
+ type t = T.region_id
+ end in
+ let module M = C.RTraitTypeRefMap in
+ let module Norm = MakeNormalizer (R) (RTyMap) (M) in
+ let rbindings =
+ Norm.compute_norm_trait_types_from_preds trait_type_constraints
+ in
+ (* Compute the normalization map for the types with erased regions *)
+ let ebindings =
+ List.map
+ (fun (k, v) ->
+ ( trait_type_ref_substitute Subst.erase_regions_subst k,
+ Subst.erase_regions v ))
+ (M.bindings rbindings)
+ in
+ (C.ETraitTypeRefMap.of_list ebindings, rbindings)
+
+let ctx_add_norm_trait_stypes_from_preds (ctx : C.eval_ctx)
+ (trait_type_constraints : T.strait_type_constraint list) : C.eval_ctx =
+ let norm_trait_stypes =
+ compute_norm_trait_stypes_from_preds trait_type_constraints
+ in
+ { ctx with C.norm_trait_stypes }
+
+let ctx_add_norm_trait_types_from_preds (ctx : C.eval_ctx)
+ (trait_type_constraints : T.rtrait_type_constraint list) : C.eval_ctx =
+ let norm_trait_etypes, norm_trait_rtypes =
+ compute_norm_trait_types_from_preds trait_type_constraints
+ in
+ { ctx with C.norm_trait_etypes; norm_trait_rtypes }
+
+(** A trait instance id refers to a local clause if it only uses the variants:
+ [Self], [Clause], [ParentClause], [ItemClause] *)
+let rec trait_instance_id_is_local_clause (id : 'r T.trait_instance_id) : bool =
+ match id with
+ | T.Self | Clause _ -> true
+ | TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ | FnPointer _ ->
+ false
+ | ParentClause (id, _, _) | ItemClause (id, _, _, _) ->
+ trait_instance_id_is_local_clause id
+
+(** About the conversion functions: for now we need them (TODO: merge ety, rty, etc.),
+ but they should be applied to types without regions.
+ *)
+type 'r norm_ctx = {
+ ctx : C.eval_ctx;
+ get_ty_repr : 'r C.trait_type_ref -> 'r T.ty option;
+ convert_ety : T.ety -> 'r T.ty; (* TODO: remove? *)
+ convert_etrait_ref : T.etrait_ref -> 'r T.trait_ref; (* TODO: remove? *)
+ ty_to_string : 'r T.ty -> string;
+ generic_params_to_string : T.generic_params -> string;
+ generic_args_to_string : 'r T.generic_args -> string;
+ trait_ref_to_string : 'r T.trait_ref -> string;
+ trait_instance_id_to_string : 'r T.trait_instance_id -> string;
+ pp_r : Format.formatter -> 'r -> unit;
+}
+
+(** Small utility to lookup trait impls, together with a substitution.
+
+ Remark: one reason we have those small helpers is that all functions are
+ parameterized by a type variable 'r. The OCaml type inferencer and type
+ checker are however not very good at generating precise error messages in
+ this context: if in the body of the function we have an overly constrained
+ usage of 'r (for instance, the type inferencer deduces 'r should be
+ [T.erased_region]), it will not be able to pinpoint the location which
+ introduced the constraints and we just get a type-checking error for the
+ whole function. The fact that we have mutually recursive functions makes it
+ worse (the type-checker sometimes indicates a well-typed function as not
+ well-typed, because it calls a not well-typed function...).
+ By isolating the places where such errors typically happen in small helpers
+ (i.e., the places where we convert between different types of regions by
+ performing substitutions), we make maintenance a lot easier.
+ *)
+let ctx_lookup_trait_impl :
+ 'r.
+ 'r norm_ctx ->
+ T.TraitImplId.id ->
+ 'r T.generic_args ->
+ A.trait_impl * (T.region_var_id T.region, 'r) Subst.subst =
+ fun ctx impl_id generics ->
+ (* Lookup the implementation *)
+ let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in
+ (* The substitution *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst =
+ Subst.make_subst_from_generics_no_regions trait_impl.generics generics
+ tr_self
+ in
+ (* Return *)
+ (trait_impl, subst)
+
+let ctx_lookup_trait_impl_ty :
+ 'r.
+ 'r norm_ctx -> T.TraitImplId.id -> 'r T.generic_args -> string -> 'r T.ty
+ =
+ fun ctx impl_id generics type_name ->
+ (* Lookup the implementation *)
+ let trait_impl, subst = ctx_lookup_trait_impl ctx impl_id generics in
+ (* Lookup the type *)
+ let ty = snd (List.assoc type_name trait_impl.types) in
+ (* Annoying: convert etype to an stype - TODO: how to avoid that? *)
+ let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in
+ (* Substitute *)
+ Subst.ty_substitute subst ty
+
+let ctx_lookup_trait_impl_parent_clause :
+ 'r.
+ 'r norm_ctx ->
+ T.TraitImplId.id ->
+ 'r T.generic_args ->
+ T.TraitClauseId.id ->
+ 'r T.trait_ref =
+ fun ctx impl_id generics clause_id ->
+ (* Lookup the implementation *)
+ let trait_impl, subst = ctx_lookup_trait_impl ctx impl_id generics in
+ (* Lookup the clause *)
+ let clause = T.TraitClauseId.nth trait_impl.parent_trait_refs clause_id in
+ (* Sanity check: the clause necessarily refers to an impl *)
+ let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in
+ (* Substitute *)
+ Subst.trait_ref_substitute subst clause
+
+let ctx_lookup_trait_impl_item_clause :
+ 'r.
+ 'r norm_ctx ->
+ T.TraitImplId.id ->
+ 'r T.generic_args ->
+ string ->
+ T.TraitClauseId.id ->
+ 'r T.trait_ref =
+ fun ctx impl_id generics item_name clause_id ->
+ (* Lookup the implementation *)
+ let trait_impl, subst = ctx_lookup_trait_impl ctx impl_id generics in
+ (* Lookup the item then its clause *)
+ let item = List.assoc item_name trait_impl.types in
+ let clause = T.TraitClauseId.nth (fst item) clause_id in
+ (* Sanity check: the clause necessarily refers to an impl *)
+ let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in
+ (* Annoying: convert etype to an stype - TODO: how to avoid that? *)
+ let clause : T.strait_ref =
+ TypesUtils.etrait_ref_no_regions_to_gr_trait_ref clause
+ in
+ (* Substitute *)
+ Subst.trait_ref_substitute subst clause
+
+(** Normalize a type by simplifying the references to trait associated types
+ and choosing a representative when there are equalities between types
+ enforced by local clauses (i.e., `where Trait1::T = Trait2::U`.
+
+ See the comments for {!ctx_normalize_trait_instance_id}.
+ *)
+let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty =
+ fun ctx ty ->
+ log#ldebug (lazy ("ctx_normalize_ty: " ^ ctx.ty_to_string ty));
+ match ty with
+ | T.Adt (id, generics) -> Adt (id, ctx_normalize_generic_args ctx generics)
+ | TypeVar _ | Literal _ | Never -> ty
+ | Ref (r, ty, rkind) ->
+ let ty = ctx_normalize_ty ctx ty in
+ T.Ref (r, ty, rkind)
+ | RawPtr (ty, rkind) ->
+ let ty = ctx_normalize_ty ctx ty in
+ RawPtr (ty, rkind)
+ | Arrow (inputs, output) ->
+ let inputs = List.map (ctx_normalize_ty ctx) inputs in
+ let output = ctx_normalize_ty ctx output in
+ Arrow (inputs, output)
+ | TraitType (trait_ref, generics, type_name) -> (
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty:\n- trait type: " ^ ctx.ty_to_string ty
+ ^ "\n- trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref:\n"
+ ^ T.show_trait_ref ctx.pp_r trait_ref
+ ^ "\n- generics:\n"
+ ^ ctx.generic_args_to_string generics));
+ (* Normalize and attempt to project the type from the trait ref *)
+ let trait_ref = ctx_normalize_trait_ref ctx trait_ref in
+ let generics = ctx_normalize_generic_args ctx generics in
+ (* For now, we don't support higher order types *)
+ assert (generics = TypesUtils.mk_empty_generic_args);
+ let ty : 'r T.ty =
+ match trait_ref.trait_id with
+ | T.TraitRef
+ { T.trait_id = T.TraitImpl impl_id; generics = ref_generics; _ } ->
+ assert (ref_generics = TypesUtils.mk_empty_generic_args);
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty: trait type: trait ref: "
+ ^ ctx.ty_to_string ty));
+ (* Lookup the type *)
+ let ty =
+ ctx_lookup_trait_impl_ty ctx impl_id trait_ref.generics type_name
+ in
+ (* Normalize *)
+ ctx_normalize_ty ctx ty
+ | T.TraitImpl impl_id ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty (trait impl):\n- trait type: "
+ ^ ctx.ty_to_string ty ^ "\n- trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref:\n"
+ ^ T.show_trait_ref ctx.pp_r trait_ref));
+ (* This happens. This doesn't come from the substitutions
+ performed by Aeneas (the [TraitImpl] would be wrapped in a
+ [TraitRef] but from non-normalized traits translated from
+ the Rustc AST.
+ TODO: factor out with the branch above.
+ *)
+ (* Lookup the type *)
+ let ty =
+ ctx_lookup_trait_impl_ty ctx impl_id trait_ref.generics type_name
+ in
+ (* Normalize *)
+ ctx_normalize_ty ctx ty
+ | _ ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty: trait type: not a trait ref: "
+ ^ ctx.ty_to_string ty ^ "\n- trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref:\n"
+ ^ T.show_trait_ref ctx.pp_r trait_ref));
+ (* We can't project *)
+ assert (trait_instance_id_is_local_clause trait_ref.trait_id);
+ T.TraitType (trait_ref, generics, type_name)
+ in
+ let tr : 'r C.trait_type_ref = { C.trait_ref; type_name } in
+ (* Lookup the representative, if there is *)
+ match ctx.get_ty_repr tr with None -> ty | Some ty -> ty)
+
+(** This returns the normalized trait instance id together with an optional
+ reference to a trait **implementation** (the `trait_ref` we return has
+ necessarily for instance id a [TraitImpl]).
+
+ We need this in particular to simplify the trait instance ids after we
+ performed a substitution.
+
+ Example:
+ ========
+ {[
+ trait Trait {
+ type S
+ }
+
+ impl TraitImpl for Foo {
+ type S = usize
+ }
+
+ fn f<T : Trait>(...) -> T::S;
+
+ ...
+ let x = f<Foo>[TraitImpl](...);
+ (* The return type of the call to f is:
+ T::S ~~> TraitImpl::S ~~> usize
+ *)
+ ]}
+
+ Several remarks:
+ - as we do not allow higher-order types (yet) then local clauses (and
+ sub-clauses) can't have generic arguments
+ - the [TraitRef] case only happens because of substitution, the role of
+ the normalization is in particular to eliminate it. Inside a [TraitRef]
+ there is necessarily:
+ - an id referencing a local (sub-)clause, that is an id using the variants
+ [Self], [Clause], [ItemClause] and [ParentClause] exclusively. We can't
+ simplify those cases: all we can do is remove the [TraitRef] wrapper
+ by leveraging the fact that the generic arguments must be empty.
+ - a [TraitImpl]. Note that the [TraitImpl] is necessarily just a [TraitImpl],
+ it can't be for instance a [ParentClause(TraitImpl ...)] because the
+ trait resolution would then directly reference the implementation
+ designated by [ParentClause(TraitImpl ...)] (and same for the other cases).
+ In this case we can lookup the trait implementation and recursively project
+ over it.
+ *)
+and ctx_normalize_trait_instance_id :
+ 'r.
+ 'r norm_ctx ->
+ 'r T.trait_instance_id ->
+ 'r T.trait_instance_id * 'r T.trait_ref option =
+ fun ctx id ->
+ match id with
+ | Self -> (id, None)
+ | TraitImpl _ ->
+ (* The [TraitImpl] shouldn't be inside any projection - we check this
+ elsewhere by asserting that whenever we return [None] for the impl
+ trait ref, then the id actually refers to a local clause. *)
+ (id, None)
+ | Clause _ -> (id, None)
+ | BuiltinOrAuto _ -> (id, None)
+ | ParentClause (inst_id, decl_id, clause_id) -> (
+ let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in
+ (* Check if the inst_id refers to a specific implementation, if yes project *)
+ match impl with
+ | None ->
+ (* This is actually a local clause *)
+ assert (trait_instance_id_is_local_clause inst_id);
+ (ParentClause (inst_id, decl_id, clause_id), None)
+ | Some impl ->
+ (* We figure out the parent clause by doing the following:
+ {[
+ // The implementation we are looking at
+ impl Impl1 : Trait1 { ... }
+
+ // Check the trait it implements
+ trait Trait1 : ParentTrait1 + ParentTrait2 { ... }
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ those are the parent clauses
+ ]}
+ *)
+ (* Lookup the clause *)
+ let impl_id =
+ TypesUtils.trait_instance_id_as_trait_impl impl.trait_id
+ in
+ let clause =
+ ctx_lookup_trait_impl_parent_clause ctx impl_id impl.generics
+ clause_id
+ in
+ (* Normalize the clause *)
+ let clause = ctx_normalize_trait_ref ctx clause in
+ (TraitRef clause, Some clause))
+ | ItemClause (inst_id, decl_id, item_name, clause_id) -> (
+ let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in
+ (* Check if the inst_id refers to a specific implementation, if yes project *)
+ match impl with
+ | None ->
+ (* This is actually a local clause *)
+ assert (trait_instance_id_is_local_clause inst_id);
+ (ItemClause (inst_id, decl_id, item_name, clause_id), None)
+ | Some impl ->
+ (* We figure out the item clause by doing the following:
+ {[
+ // The implementation we are looking at
+ impl Impl1 : Trait1<R> {
+ type S = ...
+ with Impl2 : Trait2 ... // Instances satisfying the declared bounds
+ ^^^^^^^^^^^^^^^^^^
+ Lookup the clause from here
+ }
+ ]}
+ *)
+ (* Lookup the impl *)
+ let impl_id =
+ TypesUtils.trait_instance_id_as_trait_impl impl.trait_id
+ in
+ let clause =
+ ctx_lookup_trait_impl_item_clause ctx impl_id impl.generics
+ item_name clause_id
+ in
+ (* Normalize the clause *)
+ let clause = ctx_normalize_trait_ref ctx clause in
+ (TraitRef clause, Some clause))
+ | TraitRef { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref } ->
+ (* We can't simplify the id *yet* : we will simplify it when projecting.
+ However, we have an implementation to return *)
+ (* Normalize the generics *)
+ let generics = ctx_normalize_generic_args ctx generics in
+ let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in
+ let trait_ref : 'r T.trait_ref =
+ { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref }
+ in
+ (TraitRef trait_ref, Some trait_ref)
+ | TraitRef trait_ref ->
+ (* The trait instance id necessarily refers to a local sub-clause. We
+ can't project over it and can only peel off the [TraitRef] wrapper *)
+ assert (trait_instance_id_is_local_clause trait_ref.trait_id);
+ assert (trait_ref.generics = TypesUtils.mk_empty_generic_args);
+ (trait_ref.trait_id, None)
+ | FnPointer ty ->
+ let ty = ctx_normalize_ty ctx ty in
+ (* TODO: we might want to return the ref to the function pointer,
+ in order to later normalize a call to this function pointer *)
+ (FnPointer ty, None)
+ | UnknownTrait _ ->
+ (* This is actually an error case *)
+ (id, None)
+
+and ctx_normalize_generic_args (ctx : 'r norm_ctx)
+ (generics : 'r T.generic_args) : 'r T.generic_args =
+ let { T.regions; types; const_generics; trait_refs } = generics in
+ let types = List.map (ctx_normalize_ty ctx) types in
+ let trait_refs = List.map (ctx_normalize_trait_ref ctx) trait_refs in
+ { T.regions; types; const_generics; trait_refs }
+
+and ctx_normalize_trait_ref (ctx : 'r norm_ctx) (trait_ref : 'r T.trait_ref) :
+ 'r T.trait_ref =
+ log#ldebug
+ (lazy
+ ("ctx_normalize_trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref:\n"
+ ^ T.show_trait_ref ctx.pp_r trait_ref));
+ let { T.trait_id; generics; trait_decl_ref } = trait_ref in
+ (* Check if the id is an impl, otherwise normalize it *)
+ let trait_id, norm_trait_ref = ctx_normalize_trait_instance_id ctx trait_id in
+ match norm_trait_ref with
+ | None ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_trait_ref: no norm: "
+ ^ ctx.trait_instance_id_to_string trait_id));
+ let generics = ctx_normalize_generic_args ctx generics in
+ let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in
+ { T.trait_id; generics; trait_decl_ref }
+ | Some trait_ref ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_trait_ref: normalized to: "
+ ^ ctx.trait_ref_to_string trait_ref));
+ assert (generics = TypesUtils.mk_empty_generic_args);
+ trait_ref
+
+(* Not sure this one is really necessary *)
+and ctx_normalize_trait_decl_ref (ctx : 'r norm_ctx)
+ (trait_decl_ref : 'r T.trait_decl_ref) : 'r T.trait_decl_ref =
+ let { T.trait_decl_id; decl_generics } = trait_decl_ref in
+ let decl_generics = ctx_normalize_generic_args ctx decl_generics in
+ { T.trait_decl_id; decl_generics }
+
+let ctx_normalize_trait_type_constraint (ctx : 'r norm_ctx)
+ (ttc : 'r T.trait_type_constraint) : 'r T.trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let trait_ref = ctx_normalize_trait_ref ctx trait_ref in
+ let generics = ctx_normalize_generic_args ctx generics in
+ let ty = ctx_normalize_ty ctx ty in
+ { T.trait_ref; generics; type_name; ty }
+
+let generic_params_to_string ctx x =
+ "<" ^ String.concat ", " (fst (PA.generic_params_to_strings ctx x)) ^ ">"
+
+let mk_snorm_ctx (ctx : C.eval_ctx) : T.RegionVarId.id T.region norm_ctx =
+ let get_ty_repr x = C.STraitTypeRefMap.find_opt x ctx.norm_trait_stypes in
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = TypesUtils.ety_no_regions_to_sty;
+ convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref;
+ ty_to_string = PA.sty_to_string ctx;
+ generic_params_to_string = generic_params_to_string ctx;
+ generic_args_to_string = PA.sgeneric_args_to_string ctx;
+ trait_ref_to_string = PA.strait_ref_to_string ctx;
+ trait_instance_id_to_string = PA.strait_instance_id_to_string ctx;
+ pp_r = T.pp_region T.pp_region_var_id;
+ }
+
+let mk_rnorm_ctx (ctx : C.eval_ctx) : T.RegionId.id T.region norm_ctx =
+ let get_ty_repr x = C.RTraitTypeRefMap.find_opt x ctx.norm_trait_rtypes in
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = TypesUtils.ety_no_regions_to_rty;
+ convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref;
+ ty_to_string = PA.rty_to_string ctx;
+ generic_params_to_string = generic_params_to_string ctx;
+ generic_args_to_string = PA.rgeneric_args_to_string ctx;
+ trait_ref_to_string = PA.rtrait_ref_to_string ctx;
+ trait_instance_id_to_string = PA.rtrait_instance_id_to_string ctx;
+ pp_r = T.pp_region T.pp_region_id;
+ }
+
+let mk_enorm_ctx (ctx : C.eval_ctx) : T.erased_region norm_ctx =
+ let get_ty_repr x = C.ETraitTypeRefMap.find_opt x ctx.norm_trait_etypes in
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = (fun x -> x);
+ convert_etrait_ref = (fun x -> x);
+ ty_to_string = PA.ety_to_string ctx;
+ generic_params_to_string = generic_params_to_string ctx;
+ generic_args_to_string = PA.egeneric_args_to_string ctx;
+ trait_ref_to_string = PA.etrait_ref_to_string ctx;
+ trait_instance_id_to_string = PA.etrait_instance_id_to_string ctx;
+ pp_r = T.pp_erased_region;
+ }
+
+let ctx_normalize_sty (ctx : C.eval_ctx) (ty : T.sty) : T.sty =
+ ctx_normalize_ty (mk_snorm_ctx ctx) ty
+
+let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty =
+ ctx_normalize_ty (mk_rnorm_ctx ctx) ty
+
+let ctx_normalize_ety (ctx : C.eval_ctx) (ty : T.ety) : T.ety =
+ ctx_normalize_ty (mk_enorm_ctx ctx) ty
+
+let ctx_normalize_rtrait_type_constraint (ctx : C.eval_ctx)
+ (ttc : T.rtrait_type_constraint) : T.rtrait_type_constraint =
+ ctx_normalize_trait_type_constraint (mk_rnorm_ctx ctx) ttc
+
+(** Same as [type_decl_get_instantiated_variants_fields_rtypes] but normalizes the types *)
+let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx)
+ (def : T.type_decl) (generics : T.rgeneric_args) :
+ (T.VariantId.id option * T.rty list) list =
+ let res =
+ Subst.type_decl_get_instantiated_variants_fields_rtypes def generics
+ in
+ List.map
+ (fun (variant_id, types) ->
+ (variant_id, List.map (ctx_normalize_rty ctx) types))
+ res
+
+(** Same as [type_decl_get_instantiated_field_rtypes] but normalizes the types *)
+let type_decl_get_inst_norm_field_rtypes (ctx : C.eval_ctx) (def : T.type_decl)
+ (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) :
+ T.rty list =
+ let types =
+ Subst.type_decl_get_instantiated_field_rtypes def opt_variant_id generics
+ in
+ List.map (ctx_normalize_rty ctx) types
+
+(** Same as [ctx_adt_value_get_instantiated_field_rtypes] but normalizes the types *)
+let ctx_adt_value_get_inst_norm_field_rtypes (ctx : C.eval_ctx)
+ (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) :
+ T.rty list =
+ let types =
+ Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id generics
+ in
+ List.map (ctx_normalize_rty ctx) types
+
+(** Same as [ctx_adt_value_get_instantiated_field_etypes] but normalizes the types *)
+let type_decl_get_inst_norm_field_etypes (ctx : C.eval_ctx) (def : T.type_decl)
+ (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) :
+ T.ety list =
+ let types =
+ Subst.type_decl_get_instantiated_field_etypes def opt_variant_id generics
+ in
+ List.map (ctx_normalize_ety ctx) types
+
+(** Same as [ctx_adt_get_instantiated_field_etypes] but normalizes the types *)
+let ctx_adt_get_inst_norm_field_etypes (ctx : C.eval_ctx)
+ (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
+ (generics : T.egeneric_args) : T.ety list =
+ let types =
+ Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id
+ generics
+ in
+ List.map (ctx_normalize_ety ctx) types
+
+(** Same as [substitute_signature] but normalizes the types *)
+let ctx_subst_norm_signature (ctx : C.eval_ctx)
+ (asubst : T.RegionGroupId.id -> V.AbstractionId.id)
+ (r_subst : T.RegionVarId.id -> T.RegionId.id)
+ (ty_subst : T.TypeVarId.id -> T.rty)
+ (cg_subst : T.ConstGenericVarId.id -> T.const_generic)
+ (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id)
+ (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
+ let sg =
+ Subst.substitute_signature asubst r_subst ty_subst cg_subst tr_subst tr_self
+ sg
+ in
+ let { A.regions_hierarchy; inputs; output; trait_type_constraints } = sg in
+ let inputs = List.map (ctx_normalize_rty ctx) inputs in
+ let output = ctx_normalize_rty ctx output in
+ let trait_type_constraints =
+ List.map (ctx_normalize_rtrait_type_constraint ctx) trait_type_constraints
+ in
+ { regions_hierarchy; inputs; output; trait_type_constraints }
diff --git a/compiler/Assumed.ml b/compiler/Assumed.ml
index 11cd5666..79f6b0d4 100644
--- a/compiler/Assumed.ml
+++ b/compiler/Assumed.ml
@@ -63,200 +63,52 @@ module Sig = struct
let empty_const_generic_params : T.const_generic_var list = []
+ let mk_generic_args regions types const_generics : T.sgeneric_args =
+ { regions; types; const_generics; trait_refs = [] }
+
+ let mk_generic_params regions types const_generics : T.generic_params =
+ { regions; types; const_generics; trait_clauses = [] }
+
let mk_ref_ty (r : T.RegionVarId.id T.region) (ty : T.sty) (is_mut : bool) :
T.sty =
let ref_kind = if is_mut then T.Mut else T.Shared in
mk_ref_ty r ty ref_kind
let mk_array_ty (ty : T.sty) (cg : T.const_generic) : T.sty =
- Adt (Assumed Array, [], [ ty ], [ cg ])
+ Adt (Assumed Array, mk_generic_args [] [ ty ] [ cg ])
- let mk_slice_ty (ty : T.sty) : T.sty = Adt (Assumed Slice, [], [ ty ], [])
- let range_ty : T.sty = Adt (Assumed Range, [], [ usize_ty ], [])
+ let mk_slice_ty (ty : T.sty) : T.sty =
+ Adt (Assumed Slice, mk_generic_args [] [ ty ] [])
- (** [fn<T>(&'a mut T, T) -> T] *)
- let mem_replace_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] (* <'a> *) in
- let regions_hierarchy = [ region_group_0 ] (* [{<'a>}] *) in
- let type_params = [ type_param_0 ] (* <T> *) in
- let inputs =
- [ mk_ref_ty rvar_0 tvar_0 true (* &'a mut T *); tvar_0 (* T *) ]
+ let mk_sig generics regions_hierarchy inputs output : A.fun_sig =
+ let preds : T.predicates =
+ { regions_outlive = []; types_outlive = []; trait_type_constraints = [] }
in
- let output = tvar_0 (* T *) in
{
- region_params;
- num_early_bound_regions = 0;
+ is_unsafe = false;
+ generics;
+ preds;
+ parent_params_info = None;
regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
inputs;
output;
}
(** [fn<T>(T) -> Box<T>] *)
let box_new_sig : A.fun_sig =
- {
- region_params = [];
- num_early_bound_regions = 0;
- regions_hierarchy = [];
- type_params = [ type_param_0 ] (* <T> *);
- const_generic_params = empty_const_generic_params;
- inputs = [ tvar_0 (* T *) ];
- output = mk_box_ty tvar_0 (* Box<T> *);
- }
+ let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
+ let regions_hierarchy = [] in
+ let inputs = [ tvar_0 (* T *) ] in
+ let output = mk_box_ty tvar_0 (* Box<T> *) in
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(Box<T>) -> ()] *)
let box_free_sig : A.fun_sig =
- {
- region_params = [];
- num_early_bound_regions = 0;
- regions_hierarchy = [];
- type_params = [ type_param_0 ] (* <T> *);
- const_generic_params = empty_const_generic_params;
- inputs = [ mk_box_ty tvar_0 (* Box<T> *) ];
- output = mk_unit_ty (* () *);
- }
-
- (** Helper for [Box::deref_shared] and [Box::deref_mut].
- Returns:
- [fn<'a, T>(&'a (mut) Box<T>) -> &'a (mut) T]
- *)
- let box_deref_gen_sig (is_mut : bool) : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
- let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params = [ type_param_0 ] (* <T> *);
- const_generic_params = empty_const_generic_params;
- inputs =
- [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box<T> *) ];
- output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *);
- }
-
- (** [fn<'a, T>(&'a Box<T>) -> &'a T] *)
- let box_deref_shared_sig = box_deref_gen_sig false
-
- (** [fn<'a, T>(&'a mut Box<T>) -> &'a mut T] *)
- let box_deref_mut_sig = box_deref_gen_sig true
-
- (** [fn<T>() -> Vec<T>] *)
- let vec_new_sig : A.fun_sig =
- let region_params = [] in
+ let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
let regions_hierarchy = [] in
- let type_params = [ type_param_0 ] (* <T> *) in
- let inputs = [] in
- let output = mk_vec_ty tvar_0 (* Vec<T> *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
-
- (** [fn<T>(&'a mut Vec<T>, T)] *)
- let vec_push_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
- let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
- let inputs =
- [
- mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *);
- tvar_0 (* T *);
- ]
- in
+ let inputs = [ mk_box_ty tvar_0 (* Box<T> *) ] in
let output = mk_unit_ty (* () *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
-
- (** [fn<T>(&'a mut Vec<T>, usize, T)] *)
- let vec_insert_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
- let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
- let inputs =
- [
- mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *);
- mk_usize_ty (* usize *);
- tvar_0 (* T *);
- ]
- in
- let output = mk_unit_ty (* () *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
-
- (** [fn<T>(&'a Vec<T>) -> usize] *)
- let vec_len_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
- let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
- let inputs =
- [ mk_ref_ty rvar_0 (mk_vec_ty tvar_0) false (* &'a Vec<T> *) ]
- in
- let output = mk_usize_ty (* usize *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
-
- (** Helper:
- [fn<T>(&'a (mut) Vec<T>, usize) -> &'a (mut) T]
- *)
- let vec_index_gen_sig (is_mut : bool) : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
- let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
- let inputs =
- [
- mk_ref_ty rvar_0 (mk_vec_ty tvar_0) is_mut (* &'a (mut) Vec<T> *);
- mk_usize_ty (* usize *);
- ]
- in
- let output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
-
- (** [fn<T>(&'a Vec<T>, usize) -> &'a T] *)
- let vec_index_shared_sig : A.fun_sig = vec_index_gen_sig false
-
- (** [fn<T>(&'a mut Vec<T>, usize) -> &'a mut T] *)
- let vec_index_mut_sig : A.fun_sig = vec_index_gen_sig true
+ mk_sig generics regions_hierarchy inputs output
(** Array/slice functions *)
@@ -275,10 +127,10 @@ module Sig = struct
let mk_array_slice_borrow_sig (cgs : T.const_generic_var list)
(input_ty : T.TypeVarId.id -> T.sty) (index_ty : T.sty option)
(output_ty : T.TypeVarId.id -> T.sty) (is_mut : bool) : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] cgs (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[
mk_ref_ty rvar_0
@@ -294,15 +146,7 @@ module Sig = struct
(output_ty type_param_0.index)
is_mut (* &'a (mut) output_ty<T> *)
in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = cgs;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
let mk_array_slice_index_sig (is_array : bool) (is_mut : bool) : A.fun_sig =
(* Array<T, N> *)
@@ -328,50 +172,53 @@ module Sig = struct
let cgs = [ cg_param_0 ] in
mk_array_slice_borrow_sig cgs input_ty None output_ty is_mut
- let mk_array_slice_subslice_sig (is_array : bool) (is_mut : bool) : A.fun_sig
- =
- (* Array<T, N> *)
- let input_ty id =
- if is_array then mk_array_ty (T.TypeVar id) cgvar_0
- else mk_slice_ty (T.TypeVar id)
+ let array_repeat_sig =
+ let generics =
+ (* <T, N> *)
+ mk_generic_params [] [ type_param_0 ] [ cg_param_0 ]
in
- (* Range *)
- let index_ty = range_ty in
- (* Slice<T> *)
- let output_ty id = mk_slice_ty (T.TypeVar id) in
- let cgs = if is_array then [ cg_param_0 ] else [] in
- mk_array_slice_borrow_sig cgs input_ty (Some index_ty) output_ty is_mut
-
- let array_subslice_sig (is_mut : bool) =
- mk_array_slice_subslice_sig true is_mut
-
- let slice_subslice_sig (is_mut : bool) =
- mk_array_slice_subslice_sig false is_mut
+ let regions_hierarchy = [] (* <> *) in
+ let inputs = [ tvar_0 (* T *) ] in
+ let output =
+ (* [T; N] *)
+ mk_array_ty tvar_0 cgvar_0
+ in
+ mk_sig generics regions_hierarchy inputs output
(** Helper:
[fn<T>(&'a [T]) -> usize]
*)
let slice_len_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[ mk_ref_ty rvar_0 (mk_slice_ty tvar_0) false (* &'a [T] *) ]
in
let output = mk_usize_ty (* usize *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
end
-type assumed_info = A.assumed_fun_id * A.fun_sig * bool * name
+type raw_assumed_fun_info =
+ A.assumed_fun_id * A.fun_sig * bool * name * bool list option
+
+type assumed_fun_info = {
+ fun_id : A.assumed_fun_id;
+ fun_sig : A.fun_sig;
+ can_fail : bool;
+ name : name;
+ keep_types : bool list option;
+ (** We may want to filter some type arguments.
+
+ For instance, all the `Vec` functions (and the `Vec` type itself) take
+ an `Allocator` type as argument, that we ignore.
+ *)
+}
+
+let mk_assumed_fun_info (raw : raw_assumed_fun_info) : assumed_fun_info =
+ let fun_id, fun_sig, can_fail, name, keep_types = raw in
+ { fun_id; fun_sig; can_fail; name; keep_types }
(** The list of assumed functions and all their information:
- their signature
@@ -384,94 +231,72 @@ type assumed_info = A.assumed_fun_id * A.fun_sig * bool * name
a [usize], we have to make sure that vectors are bounded by the max usize.
As a consequence, [Vec::push] is monadic.
*)
-let assumed_infos : assumed_info list =
- let deref_pre = [ "core"; "ops"; "deref" ] in
- let vec_pre = [ "alloc"; "vec"; "Vec" ] in
- let index_pre = [ "core"; "ops"; "index" ] in
+let raw_assumed_fun_infos : raw_assumed_fun_info list =
[
- (A.Replace, Sig.mem_replace_sig, false, to_name [ "core"; "mem"; "replace" ]);
- (BoxNew, Sig.box_new_sig, false, to_name [ "alloc"; "boxed"; "Box"; "new" ]);
+ ( BoxNew,
+ Sig.box_new_sig,
+ false,
+ to_name [ "alloc"; "boxed"; "Box"; "new" ],
+ Some [ true; false ] );
+ (* BoxFree shouldn't be used *)
( BoxFree,
Sig.box_free_sig,
false,
- to_name [ "alloc"; "boxed"; "Box"; "free" ] );
- ( BoxDeref,
- Sig.box_deref_shared_sig,
- false,
- to_name (deref_pre @ [ "Deref"; "deref" ]) );
- ( BoxDerefMut,
- Sig.box_deref_mut_sig,
- false,
- to_name (deref_pre @ [ "DerefMut"; "deref_mut" ]) );
- (VecNew, Sig.vec_new_sig, false, to_name (vec_pre @ [ "new" ]));
- (VecPush, Sig.vec_push_sig, true, to_name (vec_pre @ [ "push" ]));
- (VecInsert, Sig.vec_insert_sig, true, to_name (vec_pre @ [ "insert" ]));
- (VecLen, Sig.vec_len_sig, false, to_name (vec_pre @ [ "len" ]));
- ( VecIndex,
- Sig.vec_index_shared_sig,
- true,
- to_name (index_pre @ [ "Index"; "index" ]) );
- ( VecIndexMut,
- Sig.vec_index_mut_sig,
- true,
- to_name (index_pre @ [ "IndexMut"; "index_mut" ]) );
+ to_name [ "alloc"; "boxed"; "Box"; "free" ],
+ Some [ true; false ] );
(* Array Index *)
( ArrayIndexShared,
Sig.array_index_sig false,
true,
- to_name [ "@ArrayIndexShared" ] );
- (ArrayIndexMut, Sig.array_index_sig true, true, to_name [ "@ArrayIndexMut" ]);
+ to_name [ "@ArrayIndexShared" ],
+ None );
+ ( ArrayIndexMut,
+ Sig.array_index_sig true,
+ true,
+ to_name [ "@ArrayIndexMut" ],
+ None );
(* Array to slice*)
( ArrayToSliceShared,
Sig.array_to_slice_sig false,
true,
- to_name [ "@ArrayToSliceShared" ] );
+ to_name [ "@ArrayToSliceShared" ],
+ None );
( ArrayToSliceMut,
Sig.array_to_slice_sig true,
true,
- to_name [ "@ArrayToSliceMut" ] );
- (* Array Subslice *)
- ( ArraySubsliceShared,
- Sig.array_subslice_sig false,
- true,
- to_name [ "@ArraySubsliceShared" ] );
- ( ArraySubsliceMut,
- Sig.array_subslice_sig true,
- true,
- to_name [ "@ArraySubsliceMut" ] );
+ to_name [ "@ArrayToSliceMut" ],
+ None );
+ (* Array Repeat *)
+ (ArrayRepeat, Sig.array_repeat_sig, false, to_name [ "@ArrayRepeat" ], None);
(* Slice Index *)
( SliceIndexShared,
Sig.slice_index_sig false,
true,
- to_name [ "@SliceIndexShared" ] );
- (SliceIndexMut, Sig.slice_index_sig true, true, to_name [ "@SliceIndexMut" ]);
- (* Slice Subslice *)
- ( SliceSubsliceShared,
- Sig.slice_subslice_sig false,
- true,
- to_name [ "@SliceSubsliceShared" ] );
- ( SliceSubsliceMut,
- Sig.slice_subslice_sig true,
+ to_name [ "@SliceIndexShared" ],
+ None );
+ ( SliceIndexMut,
+ Sig.slice_index_sig true,
true,
- to_name [ "@SliceSubsliceMut" ] );
- (SliceLen, Sig.slice_len_sig, false, to_name [ "@SliceLen" ]);
+ to_name [ "@SliceIndexMut" ],
+ None );
+ (SliceLen, Sig.slice_len_sig, false, to_name [ "@SliceLen" ], None);
]
-let get_assumed_info (id : A.assumed_fun_id) : assumed_info =
- match List.find_opt (fun (id', _, _, _) -> id = id') assumed_infos with
+let assumed_fun_infos : assumed_fun_info list =
+ List.map mk_assumed_fun_info raw_assumed_fun_infos
+
+let get_assumed_fun_info (id : A.assumed_fun_id) : assumed_fun_info =
+ match List.find_opt (fun x -> id = x.fun_id) assumed_fun_infos with
| Some info -> info
| None ->
raise
- (Failure ("get_assumed_info: not found: " ^ A.show_assumed_fun_id id))
+ (Failure ("get_assumed_fun_info: not found: " ^ A.show_assumed_fun_id id))
-let get_assumed_sig (id : A.assumed_fun_id) : A.fun_sig =
- let _, sg, _, _ = get_assumed_info id in
- sg
+let get_assumed_fun_sig (id : A.assumed_fun_id) : A.fun_sig =
+ (get_assumed_fun_info id).fun_sig
-let get_assumed_name (id : A.assumed_fun_id) : fun_name =
- let _, _, _, name = get_assumed_info id in
- name
+let get_assumed_fun_name (id : A.assumed_fun_id) : fun_name =
+ (get_assumed_fun_info id).name
-let assumed_can_fail (id : A.assumed_fun_id) : bool =
- let _, _, b, _ = get_assumed_info id in
- b
+let assumed_fun_can_fail (id : A.assumed_fun_id) : bool =
+ (get_assumed_fun_info id).can_fail
diff --git a/compiler/Config.ml b/compiler/Config.ml
index bd80769f..a487f9e2 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -124,7 +124,7 @@ let always_deconstruct_adts_with_matches = ref false
(** Controls whether we need to use a state to model the external world
(I/O, for instance).
*)
-let use_state = ref true
+let use_state = ref false
(** Controls whether we use fuel to control termination.
*)
@@ -160,7 +160,7 @@ let backward_no_state_update = ref false
files for the types, clauses and functions, or if we group them in
one file.
*)
-let split_files = ref true
+let split_files = ref false
(** Generate the library entry point, if the crate is split between different files.
@@ -306,13 +306,6 @@ let filter_useless_monadic_calls = ref true
*)
let filter_useless_functions = ref true
-(** Obsolete. TODO: remove.
-
- For Lean we used to parameterize the entire development by a section variable
- called opaque_defs, of type OpaqueDefs.
- *)
-let wrap_opaque_in_sig = ref false
-
(** Use short names for the record fields.
Some backends can't disambiguate records when their field names have collisions.
@@ -323,3 +316,23 @@ let wrap_opaque_in_sig = ref false
information), we use short names (i.e., the original field names).
*)
let record_fields_short_names = ref false
+
+(** Parameterize the traits with their associated types, so as not to use
+ types as first class objects.
+
+ This is useful for some backends with limited expressiveness like HOL4,
+ and to account for type constraints (like [fn f<T : Foo>(...) where T::bar = usize]).
+ *)
+let parameterize_trait_types = ref false
+
+(** For sanity check: type check the generated pure code (activates checks in
+ several places).
+
+ TODO: deactivated for now because we need to implement the normalization of
+ trait associated types in the pure code.
+ *)
+let type_check_pure_code = ref false
+
+(** Shall we fail hard if we encounter an issue, or should we attempt to go
+ as far as possible while leaving "holes" in the generated code? *)
+let fail_hard = ref true
diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml
index 2ca5653d..dac64a9a 100644
--- a/compiler/Contexts.ml
+++ b/compiler/Contexts.ml
@@ -5,6 +5,7 @@ open LlbcAst
module V = Values
open ValuesUtils
open Identifiers
+module L = Logging
(** The [Id] module for dummy variables.
@@ -17,6 +18,9 @@ IdGen ()
type dummy_var_id = DummyVarId.id [@@deriving show, ord]
+(** The local logger *)
+let log = L.contexts_log
+
(** Some global counters.
Note that those counters were initially stored in {!eval_ctx} values,
@@ -40,6 +44,7 @@ type dummy_var_id = DummyVarId.id [@@deriving show, ord]
fn f x : fun_type =
let id = fresh_id () in
...
+ fun () -> ...
let g = f x in // <-- the fresh identifier gets generated here
let x1 = g () in // <-- no fresh generation here
@@ -250,27 +255,127 @@ type type_context = {
}
[@@deriving show]
-type fun_context = { fun_decls : fun_decl FunDeclId.Map.t } [@@deriving show]
+type fun_context = {
+ fun_decls : fun_decl FunDeclId.Map.t;
+ fun_infos : FunsAnalysis.fun_info FunDeclId.Map.t;
+}
+[@@deriving show]
type global_context = { global_decls : global_decl GlobalDeclId.Map.t }
[@@deriving show]
+type trait_decls_context = { trait_decls : trait_decl TraitDeclId.Map.t }
+[@@deriving show]
+
+type trait_impls_context = { trait_impls : trait_impl TraitImplId.Map.t }
+[@@deriving show]
+
+type decls_ctx = {
+ type_ctx : type_context;
+ fun_ctx : fun_context;
+ global_ctx : global_context;
+ trait_decls_ctx : trait_decls_context;
+ trait_impls_ctx : trait_impls_context;
+}
+[@@deriving show]
+
+(** A reference to a trait associated type *)
+type 'r trait_type_ref = { trait_ref : 'r trait_ref; type_name : string }
+[@@deriving show, ord]
+
+type etrait_type_ref = erased_region trait_type_ref [@@deriving show, ord]
+
+type rtrait_type_ref = Types.RegionId.id Types.region trait_type_ref
+[@@deriving show, ord]
+
+type strait_type_ref = Types.RegionVarId.id Types.region trait_type_ref
+[@@deriving show, ord]
+
+(* TODO: correctly use the functors so as not to have a duplication below *)
+module ETraitTypeRefOrd = struct
+ type t = etrait_type_ref
+
+ let compare = compare_etrait_type_ref
+ let to_string = show_etrait_type_ref
+ let pp_t = pp_etrait_type_ref
+ let show_t = show_etrait_type_ref
+end
+
+module RTraitTypeRefOrd = struct
+ type t = rtrait_type_ref
+
+ let compare = compare_rtrait_type_ref
+ let to_string = show_rtrait_type_ref
+ let pp_t = pp_rtrait_type_ref
+ let show_t = show_rtrait_type_ref
+end
+
+module STraitTypeRefOrd = struct
+ type t = strait_type_ref
+
+ let compare = compare_strait_type_ref
+ let to_string = show_strait_type_ref
+ let pp_t = pp_strait_type_ref
+ let show_t = show_strait_type_ref
+end
+
+module ETraitTypeRefMap = Collections.MakeMap (ETraitTypeRefOrd)
+module RTraitTypeRefMap = Collections.MakeMap (RTraitTypeRefOrd)
+module STraitTypeRefMap = Collections.MakeMap (STraitTypeRefOrd)
+
(** Evaluation context *)
type eval_ctx = {
type_context : type_context;
fun_context : fun_context;
global_context : global_context;
+ trait_decls_context : trait_decls_context;
+ trait_impls_context : trait_impls_context;
region_groups : RegionGroupId.id list;
type_vars : type_var list;
const_generic_vars : const_generic_var list;
+ const_generic_vars_map : typed_value Types.ConstGenericVarId.Map.t;
+ (** The map from const generic vars to their values. Those values
+ can be symbolic values or concrete values (in the latter case:
+ if we run in interpreter mode) *)
+ norm_trait_etypes : ety ETraitTypeRefMap.t;
+ (** The normalized trait types (a map from trait types to their representatives).
+ Note that this doesn't support account higher-order types. *)
+ norm_trait_rtypes : rty RTraitTypeRefMap.t;
+ (** We need this because we manipulate two kinds of types.
+ Note that we actually forbid regions from appearing both in the trait
+ references and in the constraints given to the associated types,
+ meaning that we don't have to worry about mismatches due to changes
+ in region ids.
+
+ TODO: how not to duplicate?
+ *)
+ norm_trait_stypes : sty STraitTypeRefMap.t;
+ (** We sometimes need to normalize types in non-instantiated signatures.
+
+ Note that we either need to use the etypes/rtypes maps, or the stypes map.
+ This means that we either compute the maps for etypes and rtypes, or compute
+ the one for stypes (we don't always compute and carry all the maps).
+ *)
env : env;
ended_regions : RegionId.Set.t;
}
[@@deriving show]
+let lookup_type_var_opt (ctx : eval_ctx) (vid : TypeVarId.id) : type_var option
+ =
+ if TypeVarId.to_int vid < List.length ctx.type_vars then
+ Some (TypeVarId.nth ctx.type_vars vid)
+ else None
+
let lookup_type_var (ctx : eval_ctx) (vid : TypeVarId.id) : type_var =
TypeVarId.nth ctx.type_vars vid
+let lookup_const_generic_var_opt (ctx : eval_ctx) (vid : ConstGenericVarId.id) :
+ const_generic_var option =
+ if ConstGenericVarId.to_int vid < List.length ctx.const_generic_vars then
+ Some (ConstGenericVarId.nth ctx.const_generic_vars vid)
+ else None
+
let lookup_const_generic_var (ctx : eval_ctx) (vid : ConstGenericVarId.id) :
const_generic_var =
ConstGenericVarId.nth ctx.const_generic_vars vid
@@ -304,6 +409,12 @@ let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) :
global_decl =
GlobalDeclId.Map.find gid ctx.global_context.global_decls
+let ctx_lookup_trait_decl (ctx : eval_ctx) (id : TraitDeclId.id) : trait_decl =
+ TraitDeclId.Map.find id ctx.trait_decls_context.trait_decls
+
+let ctx_lookup_trait_impl (ctx : eval_ctx) (id : TraitImplId.id) : trait_impl =
+ TraitImplId.Map.find id ctx.trait_impls_context.trait_impls
+
(** Retrieve a variable's value in the current frame *)
let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value =
snd (env_lookup_var env vid)
@@ -312,6 +423,11 @@ let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value =
let ctx_lookup_var_value (ctx : eval_ctx) (vid : VarId.id) : typed_value =
env_lookup_var_value ctx.env vid
+(** Retrieve a const generic value in an evaluation context *)
+let ctx_lookup_const_generic_value (ctx : eval_ctx) (vid : ConstGenericVarId.id)
+ : typed_value =
+ Types.ConstGenericVarId.Map.find vid ctx.const_generic_vars_map
+
(** Update a variable's value in the current frame.
This is a helper function: it can break invariants and doesn't perform
@@ -361,6 +477,15 @@ let ctx_push_var (ctx : eval_ctx) (var : var) (v : typed_value) : eval_ctx =
*)
let ctx_push_vars (ctx : eval_ctx) (vars : (var * typed_value) list) : eval_ctx
=
+ log#ldebug
+ (lazy
+ ("push_vars:\n"
+ ^ String.concat "\n"
+ (List.map
+ (fun (var, value) ->
+ (* We can unfortunately not use Print because it depends on Contexts... *)
+ show_var var ^ " -> " ^ V.show_typed_value value)
+ vars)));
assert (
List.for_all
(fun (var, (value : typed_value)) -> var.var_ty = value.ty)
diff --git a/compiler/Driver.ml b/compiler/Driver.ml
index b646a53d..128ae890 100644
--- a/compiler/Driver.ml
+++ b/compiler/Driver.ml
@@ -17,11 +17,15 @@ let log = main_log
let _ =
(* Set up the logging - for now we use default values - TODO: use the
* command-line arguments *)
- (* By setting a level for the main_logger_handler, we filter everything *)
+ (* By setting a level for the main_logger_handler, we filter everything.
+ To have a good trace: one should switch between Info and Debug.
+ *)
Easy_logging.Handlers.set_level main_logger_handler EL.Debug;
main_log#set_level EL.Info;
llbc_of_json_logger#set_level EL.Info;
pre_passes_log#set_level EL.Info;
+ associated_types_log#set_level EL.Info;
+ contexts_log#set_level EL.Info;
interpreter_log#set_level EL.Info;
statements_log#set_level EL.Info;
loops_match_ctxs_log#set_level EL.Info;
@@ -37,7 +41,7 @@ let _ =
pure_utils_log#set_level EL.Info;
symbolic_to_pure_log#set_level EL.Info;
pure_micro_passes_log#set_level EL.Info;
- pure_to_extract_log#set_level EL.Info;
+ extract_log#set_level EL.Info;
translate_log#set_level EL.Info;
scc_log#set_level EL.Info;
reorder_decls_log#set_level EL.Info
@@ -62,6 +66,9 @@ let () =
(* Read the command line arguments *)
let dest_dir = ref "" in
+ (* Print the imported llbc *)
+ let print_llbc = ref false in
+
let spec =
[
( "-backend",
@@ -86,9 +93,9 @@ let () =
Arg.Set extract_decreases_clauses,
" Use decreases clauses/termination measures for the recursive \
definitions" );
- ( "-no-state",
- Arg.Clear use_state,
- " Do not use state-error monads, simply use error monads" );
+ ( "-state",
+ Arg.Set use_state,
+ " Use a *state*-error monads, instead of an error monads" );
( "-use-fuel",
Arg.Set use_fuel,
" Use a fuel parameter to control divergence" );
@@ -99,10 +106,10 @@ let () =
Arg.Set extract_template_decreases_clauses,
" Generate templates for the required decreases clauses/termination \
measures, in a dedicated file. Implies -decreases-clauses" );
- ( "-no-split-files",
- Arg.Clear split_files,
- " Do not split the definitions between different files for types, \
- functions, etc." );
+ ( "-split-files",
+ Arg.Set split_files,
+ " Split the definitions between different files for types, functions, \
+ etc." );
( "-no-check-inv",
Arg.Clear check_invariants,
" Deactivate the invariant sanity checks performed at every evaluation \
@@ -114,6 +121,8 @@ let () =
( "-lean-default-lakefile",
Arg.Clear lean_gen_lakefile,
" Generate a default lakefile.lean (Lean only)" );
+ ("-print-llbc", Arg.Set print_llbc, " Print the imported LLBC");
+ ("-k", Arg.Clear fail_hard, " Do not fail hard in case of error");
]
in
@@ -127,6 +136,7 @@ let () =
in
if !extract_template_decreases_clauses then extract_decreases_clauses := true;
+ if !print_llbc then main_log#set_level EL.Debug;
(* Sanity check (now that the arguments are parsed!): -template-clauses ==> decrease-clauses *)
assert (!extract_decreases_clauses || not !extract_template_decreases_clauses);
@@ -158,14 +168,14 @@ let () =
| FStar ->
(* Some patterns are not supported *)
decompose_monadic_let_bindings := false;
- decompose_nested_let_patterns := false
+ decompose_nested_let_patterns := false;
+ (* F* can disambiguate the field names *)
+ record_fields_short_names := true
| Coq ->
(* Some patterns are not supported *)
decompose_monadic_let_bindings := true;
decompose_nested_let_patterns := true
| Lean ->
- (* The Lean backend is experimental: print a warning *)
- log#lwarning (lazy "The Lean backend is experimental");
(* We don't support fuel for the Lean backend *)
if !use_fuel then (
log#error "The Lean backend doesn't support the -use-fuel option";
@@ -212,28 +222,6 @@ let () =
log#linfo (lazy ("Imported: " ^ filename));
log#ldebug (lazy ("\n" ^ Print.Crate.crate_to_string m ^ "\n"));
- (* Print a warning if the crate contains loops (loops are experimental for now) *)
- let has_loops =
- A.FunDeclId.Map.exists
- (fun _ -> Aeneas.LlbcAstUtils.fun_decl_has_loops)
- m.functions
- in
- if has_loops then log#lwarning (lazy "Support for loops is experimental");
-
- (* If we target Lean, we request the crates to be split into several files
- whenever there are opaque functions *)
- if
- !backend = Lean
- && A.FunDeclId.Map.exists
- (fun _ (d : A.fun_decl) -> d.body = None)
- m.functions
- && not !split_files
- then (
- log#error
- "For Lean, we request the -split-file option whenever using opaque \
- functions";
- fail ());
-
(* We don't support mutually recursive definitions with decreases clauses in Lean *)
if
!backend = Lean && !extract_decreases_clauses
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index c4238d83..d04f5c1d 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -3,2102 +3,104 @@
the formatter everywhere...
*)
-open Utils
open Pure
open PureUtils
open TranslateCore
open ExtractBase
-open StringUtils
open Config
-module F = Format
-
-(** Small helper to compute the name of an int type *)
-let int_name (int_ty : integer_type) =
- let isize, usize, i_format, u_format =
- match !backend with
- | FStar | Coq | HOL4 ->
- ("isize", "usize", format_of_string "i%d", format_of_string "u%d")
- | Lean -> ("Isize", "Usize", format_of_string "I%d", format_of_string "U%d")
- in
- match int_ty with
- | Isize -> isize
- | I8 -> Printf.sprintf i_format 8
- | I16 -> Printf.sprintf i_format 16
- | I32 -> Printf.sprintf i_format 32
- | I64 -> Printf.sprintf i_format 64
- | I128 -> Printf.sprintf i_format 128
- | Usize -> usize
- | U8 -> Printf.sprintf u_format 8
- | U16 -> Printf.sprintf u_format 16
- | U32 -> Printf.sprintf u_format 32
- | U64 -> Printf.sprintf u_format 64
- | U128 -> Printf.sprintf u_format 128
-
-(** Small helper to compute the name of a unary operation *)
-let unop_name (unop : unop) : string =
- match unop with
- | Not -> (
- match !backend with FStar | Lean -> "not" | Coq -> "negb" | HOL4 -> "~")
- | Neg (int_ty : integer_type) -> (
- match !backend with Lean -> "-" | _ -> int_name int_ty ^ "_neg")
- | Cast _ ->
- (* We never directly use the unop name in this case *)
- raise (Failure "Unsupported")
-
-(** Small helper to compute the name of a binary operation (note that many
- binary operations like "less than" are extracted to primitive operations,
- like [<]).
- *)
-let named_binop_name (binop : E.binop) (int_ty : integer_type) : string =
- let binop =
- match binop with
- | Div -> "div"
- | Rem -> "rem"
- | Add -> "add"
- | Sub -> "sub"
- | Mul -> "mul"
- | Lt -> "lt"
- | Le -> "le"
- | Ge -> "ge"
- | Gt -> "gt"
- | _ -> raise (Failure "Unreachable")
- in
- (* Remark: the Lean case is actually not used *)
- match !backend with
- | Lean -> int_name int_ty ^ "." ^ binop
- | FStar | Coq | HOL4 -> int_name int_ty ^ "_" ^ binop
-
-(** A list of keywords/identifiers used by the backend and with which we
- want to check collision.
-
- Remark: this is useful mostly to look for collisions when generating
- names for *variables*.
- *)
-let keywords () =
- let named_unops =
- unop_name Not
- :: List.map (fun it -> unop_name (Neg it)) T.all_signed_int_types
- in
- let named_binops = [ E.Div; Rem; Add; Sub; Mul ] in
- let named_binops =
- List.concat_map
- (fun bn -> List.map (fun it -> named_binop_name bn it) T.all_int_types)
- named_binops
- in
- let misc =
- match !backend with
- | FStar ->
- [
- "assert";
- "assert_norm";
- "assume";
- "else";
- "fun";
- "fn";
- "FStar";
- "FStar.Mul";
- "if";
- "in";
- "include";
- "int";
- "let";
- "list";
- "match";
- "not";
- "open";
- "rec";
- "scalar_cast";
- "then";
- "type";
- "Type0";
- "Type";
- "unit";
- "val";
- "with";
- ]
- | Coq ->
- [
- "assert";
- "Arguments";
- "Axiom";
- "char_of_byte";
- "Check";
- "Declare";
- "Definition";
- "else";
- "End";
- "fun";
- "Fixpoint";
- "if";
- "in";
- "int";
- "Inductive";
- "Import";
- "let";
- "Lemma";
- "match";
- "Module";
- "not";
- "Notation";
- "Proof";
- "Qed";
- "rec";
- "Record";
- "Require";
- "Scope";
- "Search";
- "SearchPattern";
- "Set";
- "then";
- (* [tt] is unit *)
- "tt";
- "type";
- "Type";
- "unit";
- "with";
- ]
- | Lean ->
- [
- "by";
- "class";
- "decreasing_by";
- "def";
- "deriving";
- "do";
- "else";
- "end";
- "for";
- "have";
- "if";
- "inductive";
- "instance";
- "import";
- "let";
- "macro";
- "match";
- "namespace";
- "opaque";
- "open";
- "run_cmd";
- "set_option";
- "simp";
- "structure";
- "syntax";
- "termination_by";
- "then";
- "Type";
- "unsafe";
- "where";
- "with";
- "opaque_defs";
- ]
- | HOL4 ->
- [
- "Axiom";
- "case";
- "Definition";
- "else";
- "End";
- "fix";
- "fix_exec";
- "fn";
- "fun";
- "if";
- "in";
- "int";
- "Inductive";
- "let";
- "of";
- "Proof";
- "QED";
- "then";
- "Theorem";
- ]
- in
- List.concat [ named_unops; named_binops; misc ]
-
-let assumed_adts () : (assumed_ty * string) list =
- match !backend with
- | Lean ->
- [
- (State, "State");
- (Result, "Result");
- (Error, "Error");
- (Fuel, "Nat");
- (Option, "Option");
- (Vec, "Vec");
- (Array, "Array");
- (Slice, "Slice");
- (Str, "Str");
- (Range, "Range");
- ]
- | Coq | FStar ->
- [
- (State, "state");
- (Result, "result");
- (Error, "error");
- (Fuel, "nat");
- (Option, "option");
- (Vec, "vec");
- (Array, "array");
- (Slice, "slice");
- (Str, "str");
- (Range, "range");
- ]
- | HOL4 ->
- [
- (State, "state");
- (Result, "result");
- (Error, "error");
- (Fuel, "num");
- (Option, "option");
- (Vec, "vec");
- (Array, "array");
- (Slice, "slice");
- (Str, "str");
- (Range, "range");
- ]
-
-let assumed_struct_constructors () : (assumed_ty * string) list =
- match !backend with
- | Lean -> [ (Range, "Range.mk"); (Array, "Array.make") ]
- | Coq -> [ (Range, "mk_range"); (Array, "mk_array") ]
- | FStar -> [ (Range, "Mkrange"); (Array, "mk_array") ]
- | HOL4 -> [ (Range, "mk_range"); (Array, "mk_array") ]
-
-let assumed_variants () : (assumed_ty * VariantId.id * string) list =
- match !backend with
- | FStar ->
- [
- (Result, result_return_id, "Return");
- (Result, result_fail_id, "Fail");
- (Error, error_failure_id, "Failure");
- (Error, error_out_of_fuel_id, "OutOfFuel");
- (* No Fuel::Zero on purpose *)
- (* No Fuel::Succ on purpose *)
- (Option, option_some_id, "Some");
- (Option, option_none_id, "None");
- ]
- | Coq ->
- [
- (Result, result_return_id, "Return");
- (Result, result_fail_id, "Fail_");
- (Error, error_failure_id, "Failure");
- (Error, error_out_of_fuel_id, "OutOfFuel");
- (Fuel, fuel_zero_id, "O");
- (Fuel, fuel_succ_id, "S");
- (Option, option_some_id, "Some");
- (Option, option_none_id, "None");
- ]
- | Lean ->
- [
- (Result, result_return_id, "ret");
- (Result, result_fail_id, "fail");
- (Error, error_failure_id, "panic");
- (* No Fuel::Zero on purpose *)
- (* No Fuel::Succ on purpose *)
- (Option, option_some_id, "some");
- (Option, option_none_id, "none");
- ]
- | HOL4 ->
- [
- (Result, result_return_id, "Return");
- (Result, result_fail_id, "Fail");
- (Error, error_failure_id, "Failure");
- (* No Fuel::Zero on purpose *)
- (* No Fuel::Succ on purpose *)
- (Option, option_some_id, "SOME");
- (Option, option_none_id, "NONE");
- ]
-
-let assumed_llbc_functions () :
- (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
- let rg0 = Some T.RegionGroupId.zero in
- match !backend with
- | FStar | Coq | HOL4 ->
- [
- (Replace, None, "mem_replace_fwd");
- (Replace, rg0, "mem_replace_back");
- (VecNew, None, "vec_new");
- (VecPush, None, "vec_push_fwd") (* Shouldn't be used *);
- (VecPush, rg0, "vec_push_back");
- (VecInsert, None, "vec_insert_fwd") (* Shouldn't be used *);
- (VecInsert, rg0, "vec_insert_back");
- (VecLen, None, "vec_len");
- (VecIndex, None, "vec_index_fwd");
- (VecIndex, rg0, "vec_index_back") (* shouldn't be used *);
- (VecIndexMut, None, "vec_index_mut_fwd");
- (VecIndexMut, rg0, "vec_index_mut_back");
- (ArrayIndexShared, None, "array_index_shared");
- (ArrayIndexMut, None, "array_index_mut_fwd");
- (ArrayIndexMut, rg0, "array_index_mut_back");
- (ArrayToSliceShared, None, "array_to_slice_shared");
- (ArrayToSliceMut, None, "array_to_slice_mut_fwd");
- (ArrayToSliceMut, rg0, "array_to_slice_mut_back");
- (ArraySubsliceShared, None, "array_subslice_shared");
- (ArraySubsliceMut, None, "array_subslice_mut_fwd");
- (ArraySubsliceMut, rg0, "array_subslice_mut_back");
- (SliceIndexShared, None, "slice_index_shared");
- (SliceIndexMut, None, "slice_index_mut_fwd");
- (SliceIndexMut, rg0, "slice_index_mut_back");
- (SliceSubsliceShared, None, "slice_subslice_shared");
- (SliceSubsliceMut, None, "slice_subslice_mut_fwd");
- (SliceSubsliceMut, rg0, "slice_subslice_mut_back");
- (SliceLen, None, "slice_len");
- ]
- | Lean ->
- [
- (Replace, None, "mem.replace");
- (Replace, rg0, "mem.replace_back");
- (VecNew, None, "Vec.new");
- (VecPush, None, "Vec.push_fwd") (* Shouldn't be used *);
- (VecPush, rg0, "Vec.push");
- (VecInsert, None, "Vec.insert_fwd") (* Shouldn't be used *);
- (VecInsert, rg0, "Vec.insert");
- (VecLen, None, "Vec.len");
- (VecIndex, None, "Vec.index_shared");
- (VecIndex, rg0, "Vec.index_shared_back") (* shouldn't be used *);
- (VecIndexMut, None, "Vec.index_mut");
- (VecIndexMut, rg0, "Vec.index_mut_back");
- (ArrayIndexShared, None, "Array.index_shared");
- (ArrayIndexMut, None, "Array.index_mut");
- (ArrayIndexMut, rg0, "Array.index_mut_back");
- (ArrayToSliceShared, None, "Array.to_slice_shared");
- (ArrayToSliceMut, None, "Array.to_slice_mut");
- (ArrayToSliceMut, rg0, "Array.to_slice_mut_back");
- (ArraySubsliceShared, None, "Array.subslice_shared");
- (ArraySubsliceMut, None, "Array.subslice_mut");
- (ArraySubsliceMut, rg0, "Array.subslice_mut_back");
- (SliceIndexShared, None, "Slice.index_shared");
- (SliceIndexMut, None, "Slice.index_mut");
- (SliceIndexMut, rg0, "Slice.index_mut_back");
- (SliceSubsliceShared, None, "Slice.subslice_shared");
- (SliceSubsliceMut, None, "Slice.subslice_mut");
- (SliceSubsliceMut, rg0, "Slice.subslice_mut_back");
- (SliceLen, None, "Slice.len");
- ]
-
-let assumed_pure_functions () : (pure_assumed_fun_id * string) list =
- match !backend with
- | FStar ->
- [
- (Return, "return");
- (Fail, "fail");
- (Assert, "massert");
- (FuelDecrease, "decrease");
- (FuelEqZero, "is_zero");
- ]
- | Coq ->
- (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *)
- [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ]
- | Lean ->
- (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *)
- [ (Return, "return"); (Fail, "fail_"); (Assert, "massert") ]
- | HOL4 ->
- (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *)
- [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ]
-
-let names_map_init () : names_map_init =
- {
- keywords = keywords ();
- assumed_adts = assumed_adts ();
- assumed_structs = assumed_struct_constructors ();
- assumed_variants = assumed_variants ();
- assumed_llbc_functions = assumed_llbc_functions ();
- assumed_pure_functions = assumed_pure_functions ();
- }
-
-let extract_unop (extract_expr : bool -> texpression -> unit)
- (fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit
- =
- match unop with
- | Not | Neg _ ->
- let unop = unop_name unop in
- if inside then F.pp_print_string fmt "(";
- F.pp_print_string fmt unop;
- F.pp_print_space fmt ();
- extract_expr true arg;
- if inside then F.pp_print_string fmt ")"
- | Cast (src, tgt) -> (
- (* HOL4 has a special treatment: because it doesn't support dependent
- types, we don't have a specific operator for the cast *)
- match !backend with
- | HOL4 ->
- (* Casting, say, an u32 to an i32 would be done as follows:
- {[
- mk_i32 (u32_to_int x)
- ]}
- *)
- if inside then F.pp_print_string fmt "(";
- F.pp_print_string fmt ("mk_" ^ int_name tgt);
- F.pp_print_space fmt ();
- F.pp_print_string fmt "(";
- F.pp_print_string fmt (int_name src ^ "_to_int");
- F.pp_print_space fmt ();
- extract_expr true arg;
- F.pp_print_string fmt ")";
- if inside then F.pp_print_string fmt ")"
- | FStar | Coq | Lean ->
- (* Rem.: the source type is an implicit parameter *)
- if inside then F.pp_print_string fmt "(";
- let cast_str =
- match !backend with
- | Coq | FStar -> "scalar_cast"
- | Lean -> (* TODO: I8.cast, I16.cast, etc.*) "Scalar.cast"
- | HOL4 -> raise (Failure "Unreachable")
- in
- F.pp_print_string fmt cast_str;
- F.pp_print_space fmt ();
- if !backend <> Lean then (
- F.pp_print_string fmt
- (StringUtils.capitalize_first_letter
- (PrintPure.integer_type_to_string src));
- F.pp_print_space fmt ());
- if !backend = Lean then F.pp_print_string fmt ("." ^ int_name tgt)
- else
- F.pp_print_string fmt
- (StringUtils.capitalize_first_letter
- (PrintPure.integer_type_to_string tgt));
- F.pp_print_space fmt ();
- extract_expr true arg;
- if inside then F.pp_print_string fmt ")")
-
-(** [extract_expr] : the boolean argument is [inside] *)
-let extract_binop (extract_expr : bool -> texpression -> unit)
- (fmt : F.formatter) (inside : bool) (binop : E.binop)
- (int_ty : integer_type) (arg0 : texpression) (arg1 : texpression) : unit =
- if inside then F.pp_print_string fmt "(";
- (* Some binary operations have a special notation depending on the backend *)
- (match (!backend, binop) with
- | HOL4, (Eq | Ne)
- | (FStar | Coq | Lean), (Eq | Lt | Le | Ne | Ge | Gt)
- | Lean, (Div | Rem | Add | Sub | Mul) ->
- let binop =
- match binop with
- | Eq -> "="
- | Lt -> "<"
- | Le -> "<="
- | Ne -> if !backend = Lean then "!=" else "<>"
- | Ge -> ">="
- | Gt -> ">"
- | Div -> "/"
- | Rem -> "%"
- | Add -> "+"
- | Sub -> "-"
- | Mul -> "*"
- | _ -> raise (Failure "Unreachable")
- in
- let binop =
- match !backend with FStar | Lean | HOL4 -> binop | Coq -> "s" ^ binop
- in
- extract_expr false arg0;
- F.pp_print_space fmt ();
- F.pp_print_string fmt binop;
- F.pp_print_space fmt ();
- extract_expr false arg1
- | _, (Lt | Le | Ge | Gt | Div | Rem | Add | Sub | Mul) ->
- let binop = named_binop_name binop int_ty in
- F.pp_print_string fmt binop;
- F.pp_print_space fmt ();
- extract_expr true arg0;
- F.pp_print_space fmt ();
- extract_expr true arg1
- | _, (BitXor | BitAnd | BitOr | Shl | Shr) -> raise Unimplemented);
- if inside then F.pp_print_string fmt ")"
-
-let type_decl_kind_to_qualif (kind : decl_kind)
- (type_kind : type_decl_kind option) : string option =
- match !backend with
- | FStar -> (
- match kind with
- | SingleNonRec -> Some "type"
- | SingleRec -> Some "type"
- | MutRecFirst -> Some "type"
- | MutRecInner -> Some "and"
- | MutRecLast -> Some "and"
- | Assumed -> Some "assume type"
- | Declared -> Some "val")
- | Coq -> (
- match (kind, type_kind) with
- | SingleNonRec, Some Enum -> Some "Inductive"
- | SingleNonRec, Some Struct -> Some "Record"
- | (SingleRec | MutRecFirst), Some _ -> Some "Inductive"
- | (MutRecInner | MutRecLast), Some _ ->
- (* Coq doesn't support groups of mutually recursive definitions which mix
- * records and inducties: we convert everything to records if this happens
- *)
- Some "with"
- | (Assumed | Declared), None -> Some "Axiom"
- | _ -> raise (Failure "Unexpected"))
- | Lean -> (
- match kind with
- | SingleNonRec ->
- if type_kind = Some Struct then Some "structure" else Some "inductive"
- | SingleRec -> Some "inductive"
- | MutRecFirst -> Some "inductive"
- | MutRecInner -> Some "inductive"
- | MutRecLast -> Some "inductive"
- | Assumed -> Some "axiom"
- | Declared -> Some "axiom")
- | HOL4 -> None
-
-let fun_decl_kind_to_qualif (kind : decl_kind) : string option =
- match !backend with
- | FStar -> (
- match kind with
- | SingleNonRec -> Some "let"
- | SingleRec -> Some "let rec"
- | MutRecFirst -> Some "let rec"
- | MutRecInner -> Some "and"
- | MutRecLast -> Some "and"
- | Assumed -> Some "assume val"
- | Declared -> Some "val")
- | Coq -> (
- match kind with
- | SingleNonRec -> Some "Definition"
- | SingleRec -> Some "Fixpoint"
- | MutRecFirst -> Some "Fixpoint"
- | MutRecInner -> Some "with"
- | MutRecLast -> Some "with"
- | Assumed -> Some "Axiom"
- | Declared -> Some "Axiom")
- | Lean -> (
- match kind with
- | SingleNonRec -> Some "def"
- | SingleRec -> Some "divergent def"
- | MutRecFirst -> Some "mutual divergent def"
- | MutRecInner -> Some "divergent def"
- | MutRecLast -> Some "divergent def"
- | Assumed -> Some "axiom"
- | Declared -> Some "axiom")
- | HOL4 -> None
-
-(** The type of types.
-
- TODO: move inside the formatter?
- *)
-let type_keyword () =
- match !backend with
- | FStar -> "Type0"
- | Coq | Lean -> "Type"
- | HOL4 -> raise (Failure "Unexpected")
-
-(**
- [ctx]: we use the context to lookup type definitions, to retrieve type names.
- This is used to compute variable names, when they have no basenames: in this
- case we use the first letter of the type name.
-
- [variant_concatenate_type_name]: if true, add the type name as a prefix
- to the variant names.
- Ex.:
- In Rust:
- {[
- enum List = {
- Cons(u32, Box<List>),x
- Nil,
- }
- ]}
-
- F*, if option activated:
- {[
- type list =
- | ListCons : u32 -> list -> list
- | ListNil : list
- ]}
-
- F*, if option not activated:
- {[
- type list =
- | Cons : u32 -> list -> list
- | Nil : list
- ]}
-
- Rk.: this should be true by default, because in Rust all the variant names
- are actively uniquely identifier by the type name [List::Cons(...)], while
- in other languages it is not necessarily the case, and thus clashes can mess
- up type checking. Note that some languages actually forbids the name clashes
- (it is the case of F* ).
- *)
-let mk_formatter (ctx : trans_ctx) (crate_name : string)
- (variant_concatenate_type_name : bool) : formatter =
- let int_name = int_name in
-
- (* Prepare a name.
- * The first id elem is always the crate: if it is the local crate,
- * we remove it.
- * We also remove all the disambiguators, then convert everything to strings.
- * **Rmk:** because we remove the disambiguators, there may be name collisions
- * (which is ok, because we check for name collisions and fail if there is any).
- *)
- let get_name (name : name) : string list =
- (* Rmk.: initially we only filtered the disambiguators equal to 0 *)
- let name = Names.filter_disambiguators name in
- match name with
- | Ident crate :: name ->
- let name = if crate = crate_name then name else Ident crate :: name in
- let name =
- List.map
- (function
- | Names.Ident s -> s
- | Disambiguator d -> Names.Disambiguator.to_string d)
- name
- in
- name
- | _ ->
- raise (Failure ("Unexpected name shape: " ^ Print.name_to_string name))
- in
- let get_type_name = get_name in
- let type_name_to_camel_case name =
- let name = get_type_name name in
- let name = List.map to_camel_case name in
- String.concat "" name
- in
- let type_name_to_snake_case name =
- let name = get_type_name name in
- let name = List.map to_snake_case name in
- let name = String.concat "_" name in
- match !backend with
- | FStar | Lean | HOL4 -> name
- | Coq -> capitalize_first_letter name
- in
- let type_name name =
- match !backend with
- | FStar | Coq | HOL4 -> type_name_to_snake_case name ^ "_t"
- | Lean -> String.concat "." (get_type_name name)
- in
- let field_name (def_name : name) (field_id : FieldId.id)
- (field_name : string option) : string =
- let field_name =
- match field_name with
- | Some field_name -> field_name
- | None -> FieldId.to_string field_id
- in
- if !Config.record_fields_short_names then field_name
- else
- let def_name = type_name_to_snake_case def_name ^ "_" in
- def_name ^ field_name
- in
- let variant_name (def_name : name) (variant : string) : string =
- match !backend with
- | FStar | Coq | HOL4 ->
- let variant = to_camel_case variant in
- if variant_concatenate_type_name then
- type_name_to_camel_case def_name ^ variant
- else variant
- | Lean -> variant
- in
- let struct_constructor (basename : name) : string =
- let tname = type_name basename in
- let prefix =
- match !backend with FStar -> "Mk" | Coq | HOL4 -> "mk" | Lean -> ""
- in
- let suffix =
- match !backend with FStar | Coq | HOL4 -> "" | Lean -> ".mk"
- in
- prefix ^ tname ^ suffix
- in
- let get_fun_name fname =
- let fname = get_name fname in
- (* TODO: don't convert to snake case for Coq, HOL4, F* *)
- match !backend with
- | FStar | Coq | HOL4 -> String.concat "_" (List.map to_snake_case fname)
- | Lean -> String.concat "." fname
- in
- let global_name (name : global_name) : string =
- (* Converting to snake case also lowercases the letters (in Rust, global
- * names are written in capital letters). *)
- let parts = List.map to_snake_case (get_name name) in
- String.concat "_" parts
- in
- let fun_name (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option)
- (num_rgs : int) (rg : region_group_info option) (filter_info : bool * int)
- : string =
- let fname = get_fun_name fname in
- (* Compute the suffix *)
- let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in
- (* Concatenate *)
- fname ^ suffix
- in
-
- let termination_measure_name (_fid : A.FunDeclId.id) (fname : fun_name)
- (num_loops : int) (loop_id : LoopId.id option) : string =
- let fname = get_fun_name fname in
- let lp_suffix = default_fun_loop_suffix num_loops loop_id in
- (* Compute the suffix *)
- let suffix =
- match !Config.backend with
- | FStar -> "_decreases"
- | Lean -> "_terminates"
- | Coq | HOL4 -> raise (Failure "Unexpected")
- in
- (* Concatenate *)
- fname ^ lp_suffix ^ suffix
- in
-
- let decreases_proof_name (_fid : A.FunDeclId.id) (fname : fun_name)
- (num_loops : int) (loop_id : LoopId.id option) : string =
- let fname = get_fun_name fname in
- let lp_suffix = default_fun_loop_suffix num_loops loop_id in
- (* Compute the suffix *)
- let suffix =
- match !Config.backend with
- | Lean -> "_decreases"
- | FStar | Coq | HOL4 -> raise (Failure "Unexpected")
- in
- (* Concatenate *)
- fname ^ lp_suffix ^ suffix
- in
-
- let opaque_pre () =
- match !Config.backend with
- | FStar | Coq | HOL4 -> ""
- | Lean -> if !Config.wrap_opaque_in_sig then "opaque_defs." else ""
- in
-
- let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty)
- : string =
- (* If there is a basename, we use it *)
- match basename with
- | Some basename ->
- (* This should be a no-op *)
- to_snake_case basename
- | None -> (
- (* No basename: we use the first letter of the type *)
- match ty with
- | Adt (type_id, tys, _) -> (
- match type_id with
- | Tuple ->
- (* The "pair" case is frequent enough to have its special treatment *)
- if List.length tys = 2 then "p" else "t"
- | Assumed Result -> "r"
- | Assumed Error -> ConstStrings.error_basename
- | Assumed Fuel -> ConstStrings.fuel_basename
- | Assumed Option -> "opt"
- | Assumed Vec -> "v"
- | Assumed Array -> "a"
- | Assumed Slice -> "s"
- | Assumed Str -> "s"
- | Assumed Range -> "r"
- | Assumed State -> ConstStrings.state_basename
- | AdtId adt_id ->
- let def =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
- in
- (* We do the following:
- * - compute the type name, and retrieve the last ident
- * - convert this to snake case
- * - take the first letter of every "letter group"
- * Ex.: ["hashmap"; "HashMap"] ~~> "HashMap" -> "hash_map" -> "hm"
- *)
- (* Thename shouldn't be empty, and its last element should
- * be an ident *)
- let cl = List.nth def.name (List.length def.name - 1) in
- let cl = to_snake_case (Names.as_ident cl) in
- let cl = String.split_on_char '_' cl in
- let cl = List.filter (fun s -> String.length s > 0) cl in
- assert (List.length cl > 0);
- let cl = List.map (fun s -> s.[0]) cl in
- StringUtils.string_of_chars cl)
- | TypeVar _ -> (
- (* TODO: use "t" also for F* *)
- match !backend with
- | FStar -> "x" (* lacking inspiration here... *)
- | Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *))
- | Literal lty -> (
- match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i")
- | Arrow _ -> "f")
- in
- let type_var_basename (_varset : StringSet.t) (basename : string) : string =
- (* Rust type variables are snake-case and start with a capital letter *)
- match !backend with
- | FStar ->
- (* This is *not* a no-op: this removes the capital letter *)
- to_snake_case basename
- | HOL4 ->
- (* In HOL4, type variable names must start with "'" *)
- "'" ^ to_snake_case basename
- | Coq | Lean -> basename
- in
- let const_generic_var_basename (_varset : StringSet.t) (basename : string) :
- string =
- (* Rust type variables are snake-case and start with a capital letter *)
- match !backend with
- | FStar | HOL4 ->
- (* This is *not* a no-op: this removes the capital letter *)
- to_snake_case basename
- | Coq | Lean -> basename
- in
- let append_index (basename : string) (i : int) : string =
- basename ^ string_of_int i
- in
-
- let extract_literal (fmt : F.formatter) (inside : bool) (cv : literal) : unit
- =
- match cv with
- | Scalar sv -> (
- match !backend with
- | FStar -> F.pp_print_string fmt (Z.to_string sv.PV.value)
- | Coq | HOL4 ->
- let print_brackets = inside && !backend = HOL4 in
- if print_brackets then F.pp_print_string fmt "(";
- (match !backend with
- | Coq -> ()
- | HOL4 ->
- F.pp_print_string fmt ("int_to_" ^ int_name sv.PV.int_ty);
- F.pp_print_space fmt ()
- | _ -> raise (Failure "Unreachable"));
- (* We need to add parentheses if the value is negative *)
- if sv.PV.value >= Z.of_int 0 then
- F.pp_print_string fmt (Z.to_string sv.PV.value)
- else F.pp_print_string fmt ("(" ^ Z.to_string sv.PV.value ^ ")");
- (match !backend with
- | Coq -> F.pp_print_string fmt ("%" ^ int_name sv.PV.int_ty)
- | HOL4 -> ()
- | _ -> raise (Failure "Unreachable"));
- if print_brackets then F.pp_print_string fmt ")"
- | Lean ->
- F.pp_print_string fmt "(";
- F.pp_print_string fmt (int_name sv.int_ty);
- F.pp_print_string fmt ".ofInt ";
- (* Something very annoying: negated values like `-3` are
- ambiguous in Lean because of conversions, so we have to
- be extremely explicit with negative numbers.
- *)
- if Z.lt sv.value Z.zero then (
- F.pp_print_string fmt "(";
- F.pp_print_string fmt "-";
- F.pp_print_string fmt "(";
- Z.pp_print fmt (Z.neg sv.value);
- F.pp_print_string fmt ":Int";
- F.pp_print_string fmt ")";
- F.pp_print_string fmt ")")
- else Z.pp_print fmt sv.value;
- F.pp_print_string fmt ")")
- | Bool b ->
- let b =
- match !backend with
- | HOL4 -> if b then "T" else "F"
- | Coq | FStar | Lean -> if b then "true" else "false"
- in
- F.pp_print_string fmt b
- | Char c -> (
- match !backend with
- | HOL4 ->
- (* [#"a"] is a notation for [CHR 97] (97 is the ASCII code for 'a') *)
- F.pp_print_string fmt ("#\"" ^ String.make 1 c ^ "\"")
- | FStar | Lean -> F.pp_print_string fmt ("'" ^ String.make 1 c ^ "'")
- | Coq ->
- if inside then F.pp_print_string fmt "(";
- F.pp_print_string fmt "char_of_byte";
- F.pp_print_space fmt ();
- (* Convert the the char to ascii *)
- let c =
- let i = Char.code c in
- let x0 = i / 16 in
- let x1 = i mod 16 in
- "Coq.Init.Byte.x" ^ string_of_int x0 ^ string_of_int x1
- in
- F.pp_print_string fmt c;
- if inside then F.pp_print_string fmt ")")
- in
- let bool_name = if !backend = Lean then "Bool" else "bool" in
- let char_name = if !backend = Lean then "Char" else "char" in
- let str_name = if !backend = Lean then "String" else "string" in
- {
- bool_name;
- char_name;
- int_name;
- str_name;
- type_decl_kind_to_qualif;
- fun_decl_kind_to_qualif;
- field_name;
- variant_name;
- struct_constructor;
- type_name;
- global_name;
- fun_name;
- termination_measure_name;
- decreases_proof_name;
- opaque_pre;
- var_basename;
- type_var_basename;
- const_generic_var_basename;
- append_index;
- extract_literal;
- extract_unop;
- extract_binop;
- }
-
-let mk_formatter_and_names_map (ctx : trans_ctx) (crate_name : string)
- (variant_concatenate_type_name : bool) : formatter * names_map =
- let fmt = mk_formatter ctx crate_name variant_concatenate_type_name in
- let names_map = initialize_names_map fmt (names_map_init ()) in
- (fmt, names_map)
-
-let is_single_opaque_fun_decl_group (dg : Pure.fun_decl list) : bool =
- match dg with [ d ] -> d.body = None | _ -> false
-
-let is_single_opaque_type_decl_group (dg : Pure.type_decl list) : bool =
- match dg with [ d ] -> d.kind = Opaque | _ -> false
-
-let is_empty_record_type_decl (d : Pure.type_decl) : bool = d.kind = Struct []
-
-let is_empty_record_type_decl_group (dg : Pure.type_decl list) : bool =
- match dg with [ d ] -> is_empty_record_type_decl d | _ -> false
-
-(** In some provers, groups of definitions must be delimited.
-
- - in Coq, *every* group (including singletons) must end with "."
- - in Lean, groups of mutually recursive definitions must end with "end"
- - in HOL4 (in most situations) the whole group must be within a `Define` command
-
- Calls to {!extract_fun_decl} should be inserted between calls to
- {!start_fun_decl_group} and {!end_fun_decl_group}.
-
- TODO: maybe those [{start/end}_decl_group] functions are not that much a good
- idea and we should merge them with the corresponding [extract_decl] functions.
- *)
-let start_fun_decl_group (ctx : extraction_ctx) (fmt : F.formatter)
- (is_rec : bool) (dg : Pure.fun_decl list) =
- match !backend with
- | FStar | Coq | Lean -> ()
- | HOL4 ->
- (* In HOL4, opaque functions have a special treatment *)
- if is_single_opaque_fun_decl_group dg then ()
- else
- let with_opaque_pre = false in
- let compute_fun_def_name (def : Pure.fun_decl) : string =
- ctx_get_local_function with_opaque_pre def.def_id def.loop_id
- def.back_id ctx
- ^ "_def"
- in
- let names = List.map compute_fun_def_name dg in
- (* Add a break before *)
- F.pp_print_break fmt 0 0;
- (* Open the box for the delimiters *)
- F.pp_open_vbox fmt 0;
- (* Open the box for the definitions themselves *)
- F.pp_open_vbox fmt ctx.indent_incr;
- (* Print the delimiters *)
- if is_rec then
- F.pp_print_string fmt
- ("val [" ^ String.concat ", " names ^ "] = DefineDiv ‘")
- else (
- assert (List.length names = 1);
- let name = List.hd names in
- F.pp_print_string fmt ("val " ^ name ^ " = Define ‘"));
- F.pp_print_cut fmt ()
-
-(** See {!start_fun_decl_group}. *)
-let end_fun_decl_group (fmt : F.formatter) (is_rec : bool)
- (dg : Pure.fun_decl list) =
- match !backend with
- | FStar -> ()
- | Coq ->
- (* For aesthetic reasons, we print the Coq end group delimiter directly
- in {!extract_fun_decl}. *)
- ()
- | Lean ->
- (* We must add the "end" keyword to groups of mutually recursive functions *)
- if is_rec && List.length dg > 1 then (
- F.pp_print_cut fmt ();
- F.pp_print_string fmt "end";
- (* Add breaks to insert new lines between definitions *)
- F.pp_print_break fmt 0 0)
- else ()
- | HOL4 ->
- (* In HOL4, opaque functions have a special treatment *)
- if is_single_opaque_fun_decl_group dg then ()
- else (
- (* Close the box for the definitions *)
- F.pp_close_box fmt ();
- (* Print the end delimiter *)
- F.pp_print_cut fmt ();
- F.pp_print_string fmt "’";
- (* Close the box for the delimiters *)
- F.pp_close_box fmt ();
- (* Add breaks to insert new lines between definitions *)
- F.pp_print_break fmt 0 0)
-
-(** See {!start_fun_decl_group}: similar usage, but for the type declarations. *)
-let start_type_decl_group (ctx : extraction_ctx) (fmt : F.formatter)
- (is_rec : bool) (dg : Pure.type_decl list) =
- match !backend with
- | FStar | Coq -> ()
- | Lean ->
- if is_rec && List.length dg > 1 then (
- F.pp_print_space fmt ();
- F.pp_print_string fmt "mutual";
- F.pp_print_space fmt ())
- | HOL4 ->
- (* In HOL4, opaque types and empty records have a special treatment *)
- if
- is_single_opaque_type_decl_group dg
- || is_empty_record_type_decl_group dg
- then ()
- else (
- (* Add a break before *)
- F.pp_print_break fmt 0 0;
- (* Open the box for the delimiters *)
- F.pp_open_vbox fmt 0;
- (* Open the box for the definitions themselves *)
- F.pp_open_vbox fmt ctx.indent_incr;
- (* Print the delimiters *)
- F.pp_print_string fmt "Datatype:";
- F.pp_print_cut fmt ())
-
-(** See {!start_fun_decl_group}. *)
-let end_type_decl_group (fmt : F.formatter) (is_rec : bool)
- (dg : Pure.type_decl list) =
- match !backend with
- | FStar -> ()
- | Coq ->
- (* For aesthetic reasons, we print the Coq end group delimiter directly
- in {!extract_fun_decl}. *)
- ()
- | Lean ->
- (* We must add the "end" keyword to groups of mutually recursive functions *)
- if is_rec && List.length dg > 1 then (
- F.pp_print_cut fmt ();
- F.pp_print_string fmt "end";
- (* Add breaks to insert new lines between definitions *)
- F.pp_print_break fmt 0 0)
- else ()
- | HOL4 ->
- (* In HOL4, opaque types and empty records have a special treatment *)
- if
- is_single_opaque_type_decl_group dg
- || is_empty_record_type_decl_group dg
- then ()
- else (
- (* Close the box for the definitions *)
- F.pp_close_box fmt ();
- (* Print the end delimiter *)
- F.pp_print_cut fmt ();
- F.pp_print_string fmt "End";
- (* Close the box for the delimiters *)
- F.pp_close_box fmt ();
- (* Add breaks to insert new lines between definitions *)
- F.pp_print_break fmt 0 0)
-
-let unit_name () =
- match !backend with Lean -> "Unit" | Coq | FStar | HOL4 -> "unit"
-
-(** Small helper *)
-let extract_arrow (fmt : F.formatter) () : unit =
- if !Config.backend = Lean then F.pp_print_string fmt "→"
- else F.pp_print_string fmt "->"
-
-let extract_const_generic (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (cg : const_generic) : unit =
- match cg with
- | ConstGenericGlobal id ->
- let s = ctx_get_global ctx.use_opaque_pre id ctx in
- F.pp_print_string fmt s
- | ConstGenericValue v -> ctx.fmt.extract_literal fmt inside v
- | ConstGenericVar id ->
- let s = ctx_get_const_generic_var id ctx in
- F.pp_print_string fmt s
-
-let extract_literal_type (ctx : extraction_ctx) (fmt : F.formatter)
- (ty : literal_type) : unit =
- match ty with
- | Bool -> F.pp_print_string fmt ctx.fmt.bool_name
- | Char -> F.pp_print_string fmt ctx.fmt.char_name
- | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty)
-
-(** [inside] constrols whether we should add parentheses or not around type
- applications (if [true] we add parentheses).
-
- [no_params_tys]: for all the types inside this set, do not print the type parameters.
- This is used for HOL4. As polymorphism is uniform in HOL4, printing the
- type parameters in the recursive definitions is useless (and actually
- forbidden).
-
- For instance, where in F* we would write:
- {[
- type list a = | Nil : list a | Cons : a -> list a -> list a
- ]}
-
- In HOL4 we would simply write:
- {[
- Datatype:
- list = Nil 'a | Cons 'a list
- End
- ]}
- *)
-let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
- (no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit =
- let extract_rec = extract_ty ctx fmt no_params_tys in
- match ty with
- | Adt (type_id, tys, cgs) -> (
- let has_params = tys <> [] || cgs <> [] in
- match type_id with
- | Tuple ->
- (* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type:
- * we have to write [unit]... *)
- if tys = [] then F.pp_print_string fmt (unit_name ())
- else (
- F.pp_print_string fmt "(";
- Collections.List.iter_link
- (fun () ->
- F.pp_print_space fmt ();
- let product =
- match !backend with
- | FStar -> "&"
- | Coq -> "*"
- | Lean -> "×"
- | HOL4 -> "#"
- in
- F.pp_print_string fmt product;
- F.pp_print_space fmt ())
- (extract_rec true) tys;
- F.pp_print_string fmt ")")
- | AdtId _ | Assumed _ -> (
- (* HOL4 behaves differently. Where in Coq/FStar/Lean we would write:
- `tree a b`
-
- In HOL4 we would write:
- `('a, 'b) tree`
- *)
- let with_opaque_pre = false in
- match !backend with
- | FStar | Coq | Lean ->
- let print_paren = inside && has_params in
- if print_paren then F.pp_print_string fmt "(";
- (* TODO: for now, only the opaque *functions* are extracted in the
- opaque module. The opaque *types* are assumed. *)
- F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx);
- if tys <> [] then (
- F.pp_print_space fmt ();
- Collections.List.iter_link (F.pp_print_space fmt)
- (extract_rec true) tys);
- if cgs <> [] then (
- F.pp_print_space fmt ();
- Collections.List.iter_link (F.pp_print_space fmt)
- (extract_const_generic ctx fmt true)
- cgs);
- if print_paren then F.pp_print_string fmt ")"
- | HOL4 ->
- (* Const generics are unsupported in HOL4 *)
- assert (cgs = []);
- let print_tys =
- match type_id with
- | AdtId id -> not (TypeDeclId.Set.mem id no_params_tys)
- | Assumed _ -> true
- | _ -> raise (Failure "Unreachable")
- in
- if tys <> [] && print_tys then (
- let print_paren = List.length tys > 1 in
- if print_paren then F.pp_print_string fmt "(";
- Collections.List.iter_link
- (fun () ->
- F.pp_print_string fmt ",";
- F.pp_print_space fmt ())
- (extract_rec true) tys;
- if print_paren then F.pp_print_string fmt ")";
- F.pp_print_space fmt ());
- F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx)))
- | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx)
- | Literal lty -> extract_literal_type ctx fmt lty
- | Arrow (arg_ty, ret_ty) ->
- if inside then F.pp_print_string fmt "(";
- extract_rec false arg_ty;
- F.pp_print_space fmt ();
- extract_arrow fmt ();
- F.pp_print_space fmt ();
- extract_rec false ret_ty;
- if inside then F.pp_print_string fmt ")"
-
-(** Compute the names for all the top-level identifiers used in a type
- definition (type name, variant names, field names, etc. but not type
- parameters).
-
- We need to do this preemptively, beforce extracting any definition,
- because of recursive definitions.
- *)
-let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) :
- extraction_ctx =
- (* Compute and register the type def name *)
- let ctx = ctx_add_type_decl def ctx in
- (* Compute and register:
- * - the variant names, if this is an enumeration
- * - the field names, if this is a structure
- *)
- let ctx =
- match def.kind with
- | Struct fields ->
- (* Add the fields *)
- let ctx =
- fst
- (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx)
- in
- (* Add the constructor name *)
- fst (ctx_add_struct def ctx)
- | Enum variants ->
- fst
- (ctx_add_variants def
- (VariantId.mapi (fun id v -> (id, v)) variants)
- ctx)
- | Opaque ->
- (* Nothing to do *)
- ctx
- in
- (* Return *)
- ctx
-
-(** Print the variants *)
-let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter)
- (type_decl_group : TypeDeclId.Set.t) (type_name : string)
- (type_params : string list) (cg_params : string list) (cons_name : string)
- (fields : field list) : unit =
- F.pp_print_space fmt ();
- (* variant box *)
- F.pp_open_hvbox fmt ctx.indent_incr;
- (* [| Cons :]
- * Note that we really don't want any break above so we print everything
- * at once. *)
- let opt_colon = if !backend <> HOL4 then " :" else "" in
- F.pp_print_string fmt ("| " ^ cons_name ^ opt_colon);
- let print_field (fid : FieldId.id) (f : field) (ctx : extraction_ctx) :
- extraction_ctx =
- F.pp_print_space fmt ();
- (* Open the field box *)
- F.pp_open_box fmt ctx.indent_incr;
- (* Print the field names, if the backend accepts it.
- * [ x :]
- * Note that when printing fields, we register the field names as
- * *variables*: they don't need to be unique at the top level. *)
- let ctx =
- match !backend with
- | FStar -> (
- match f.field_name with
- | None -> ctx
- | Some field_name ->
- let var_id = VarId.of_int (FieldId.to_int fid) in
- let field_name =
- ctx.fmt.var_basename ctx.names_map.names_set (Some field_name)
- f.field_ty
- in
- let ctx, field_name = ctx_add_var field_name var_id ctx in
- F.pp_print_string fmt (field_name ^ " :");
- F.pp_print_space fmt ();
- ctx)
- | Coq | Lean | HOL4 -> ctx
- in
- (* Print the field type *)
- let inside = !backend = HOL4 in
- extract_ty ctx fmt type_decl_group inside f.field_ty;
- (* Print the arrow [->] *)
- if !backend <> HOL4 then (
- F.pp_print_space fmt ();
- extract_arrow fmt ());
- (* Close the field box *)
- F.pp_close_box fmt ();
- (* Return *)
- ctx
- in
- (* Print the fields *)
- let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in
- let _ =
- List.fold_left (fun ctx (fid, f) -> print_field fid f ctx) ctx fields
- in
- (* Sanity check: HOL4 doesn't support const generics *)
- assert (cg_params = [] || !backend <> HOL4);
- (* Print the final type *)
- if !backend <> HOL4 then (
- F.pp_print_space fmt ();
- F.pp_open_hovbox fmt 0;
- F.pp_print_string fmt type_name;
- List.iter
- (fun p ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt p)
- (List.append type_params cg_params);
- F.pp_close_box fmt ());
- (* Close the variant box *)
- F.pp_close_box fmt ()
-
-(* TODO: we don' need the [def_name] paramter: it can be retrieved from the context *)
-let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
- (type_decl_group : TypeDeclId.Set.t) (def : type_decl) (def_name : string)
- (type_params : string list) (cg_params : string list)
- (variants : variant list) : unit =
- (* We want to generate a definition which looks like this (taking F* as example):
- {[
- type list a = | Cons : a -> list a -> list a | Nil : list a
- ]}
-
- If there isn't enough space on one line:
- {[
- type s =
- | Cons : a -> list a -> list a
- | Nil : list a
- ]}
-
- And if we need to write the type of a variant on several lines:
- {[
- type s =
- | Cons :
- a ->
- list a ->
- list a
- | Nil : list a
- ]}
-
- Finally, it is possible to give names to the variant fields in Rust.
- In this situation, we generate a definition like this:
- {[
- type s =
- | Cons : hd:a -> tl:list a -> list a
- | Nil : list a
- ]}
-
- Note that we already printed: [type s =]
- *)
- let print_variant _variant_id (v : variant) =
- (* We don't lookup the name, because it may have a prefix for the type
- id (in the case of Lean) *)
- let cons_name = ctx.fmt.variant_name def.name v.variant_name in
- let fields = v.fields in
- extract_type_decl_variant ctx fmt type_decl_group def_name type_params
- cg_params cons_name fields
- in
- (* Print the variants *)
- let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in
- List.iter (fun (vid, v) -> print_variant vid v) variants
-
-let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
- (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl)
- (type_params : string list) (cg_params : string list) (fields : field list)
- : unit =
- (* We want to generate a definition which looks like this (taking F* as example):
- {[
- type t = { x : int; y : bool; }
- ]}
-
- If there isn't enough space on one line:
- {[
- type t =
- {
- x : int; y : bool;
- }
- ]}
-
- And if there is even less space:
- {[
- type t =
- {
- x : int;
- y : bool;
- }
- ]}
-
- Also, in case there are no fields, we need to define the type as [unit]
- ([type t = {}] doesn't work in F* ).
-
- Coq:
- ====
- We need to define the constructor name upon defining the struct (record, in Coq).
- The syntex is:
- {[
- Record Foo = mkFoo { x : int; y : bool; }.
- }]
-
- Also, Coq doesn't support groups of mutually recursive inductives and records.
- This is fine, because we can then define records as inductives, and leverage
- the fact that when record fields are accessed, the records are symbolically
- expanded which introduces let bindings of the form: [let RecordCons ... = x in ...].
- As a consequence, we never use the record projectors (unless we reconstruct
- them in the micro passes of course).
-
- HOL4:
- =====
- Type definitions are written as follows:
- {[
- Datatype:
- tree =
- TLeaf 'a
- | TNode node ;
-
- node =
- Node (tree list)
- End
- ]}
- *)
- (* Note that we already printed: [type t =] *)
- let is_rec = decl_is_from_rec_group kind in
- let _ =
- if !backend = FStar && fields = [] then (
- F.pp_print_space fmt ();
- F.pp_print_string fmt (unit_name ()))
- else if !backend = Lean && fields = [] then ()
- (* If the definition is recursive, we may need to extract it as an inductive
- (instead of a record). We start with the "normal" case: we extract it
- as a record. *)
- else if (not is_rec) || (!backend <> Coq && !backend <> Lean) then (
- if !backend <> Lean then F.pp_print_space fmt ();
- (* If Coq: print the constructor name *)
- (* TODO: remove superfluous test not is_rec below *)
- if !backend = Coq && not is_rec then (
- let with_opaque_pre = false in
- F.pp_print_string fmt
- (ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx);
- F.pp_print_string fmt " ");
- (match !backend with
- | Lean -> ()
- | FStar | Coq -> F.pp_print_string fmt "{"
- | HOL4 -> F.pp_print_string fmt "<|");
- F.pp_print_break fmt 1 ctx.indent_incr;
- (* The body itself *)
- (* Open a box for the body *)
- (match !backend with
- | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0
- | Lean -> F.pp_open_vbox fmt 0);
- (* Print the fields *)
- let print_field (field_id : FieldId.id) (f : field) : unit =
- let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in
- (* Open a box for the field *)
- F.pp_open_box fmt ctx.indent_incr;
- F.pp_print_string fmt field_name;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_ty ctx fmt type_decl_group false f.field_ty;
- if !backend <> Lean then F.pp_print_string fmt ";";
- (* Close the box for the field *)
- F.pp_close_box fmt ()
- in
- let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in
- Collections.List.iter_link (F.pp_print_space fmt)
- (fun (fid, f) -> print_field fid f)
- fields;
- (* Close the box for the body *)
- F.pp_close_box fmt ();
- match !backend with
- | Lean -> ()
- | FStar | Coq ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt "}"
- | HOL4 ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt "|>")
- else (
- (* We extract for Coq or Lean, and we have a recursive record, or a record in
- a group of mutually recursive types: we extract it as an inductive type *)
- assert (is_rec && (!backend = Coq || !backend = Lean));
- let with_opaque_pre = false in
- (* Small trick: in Lean we use namespaces, meaning we don't need to prefix
- the constructor name with the name of the type at definition site,
- i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq
- we generate `inductive Foo := | mk ... *)
- let cons_name =
- if !backend = Lean then "mk"
- else ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx
- in
- let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
- extract_type_decl_variant ctx fmt type_decl_group def_name type_params
- cg_params cons_name fields)
- in
- ()
-
-(** Extract a nestable, muti-line comment *)
-let extract_comment (fmt : F.formatter) (sl : string list) : unit =
- (* Delimiters, space after we break a line *)
- let ld, space, rd =
- match !backend with
- | Coq | FStar | HOL4 -> ("(** ", 4, " *)")
- | Lean -> ("/- ", 3, " -/")
- in
- F.pp_open_vbox fmt space;
- F.pp_print_string fmt ld;
- (match sl with
- | [] -> ()
- | s :: sl ->
- F.pp_print_string fmt s;
- List.iter
- (fun s ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt s)
- sl);
- F.pp_print_string fmt rd;
- F.pp_close_box fmt ()
-
-(** Extract a type declaration.
-
- This function is for all type declarations and all backends **at the exception**
- of opaque (assumed/declared) types format4 HOL4.
-
- See {!extract_type_decl}.
- *)
-let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
- (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl)
- (extract_body : bool) : unit =
- (* Sanity check *)
- assert (extract_body || !backend <> HOL4);
- let type_kind =
- if extract_body then
- match def.kind with
- | Struct _ -> Some Struct
- | Enum _ -> Some Enum
- | Opaque -> None
- else None
- in
- (* If in Coq and the declaration is opaque, it must have the shape:
- [Axiom Ident : forall (T0 ... Tn : Type) (N0 : ...) ... (Nn : ...), ... -> ... -> ...].
-
- The boolean [is_opaque_coq] is used to detect this case.
- *)
- let is_opaque = type_kind = None in
- let is_opaque_coq = !backend = Coq && is_opaque in
- let use_forall =
- is_opaque_coq && (def.type_params <> [] || def.const_generic_params <> [])
- in
- (* Retrieve the definition name *)
- let with_opaque_pre = false in
- let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
- (* Add the type and const generic params - note that we need those bindings only for the
- * body translation (they are not top-level) *)
- let ctx_body, type_params, cg_params =
- ctx_add_type_const_generic_params def.type_params def.const_generic_params
- ctx
- in
- let ty_cg_params = List.append type_params cg_params in
- (* Add a break before *)
- if !backend <> HOL4 || not (decl_is_first_from_group kind) then
- F.pp_print_break fmt 0 0;
- (* Print a comment to link the extracted type to its original rust definition *)
- extract_comment fmt [ "[" ^ Print.name_to_string def.name ^ "]" ];
- F.pp_print_break fmt 0 0;
- (* Open a box for the definition, so that whenever possible it gets printed on
- * one line. Note however that in the case of Lean line breaks are important
- * for parsing: we thus use a hovbox. *)
- (match !backend with
- | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0
- | Lean -> F.pp_open_vbox fmt 0);
- (* Open a box for "type TYPE_NAME (TYPE_PARAMS CONST_GEN_PARAMS) =" *)
- F.pp_open_hovbox fmt ctx.indent_incr;
- (* > "type TYPE_NAME" *)
- let qualif = ctx.fmt.type_decl_kind_to_qualif kind type_kind in
- (match qualif with
- | Some qualif -> F.pp_print_string fmt (qualif ^ " " ^ def_name)
- | None -> F.pp_print_string fmt def_name);
- (* HOL4 doesn't support const generics *)
- assert (cg_params = [] || !backend <> HOL4);
- (* Print the type/const generic parameters *)
- if ty_cg_params <> [] && !backend <> HOL4 then (
- if use_forall then (
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "forall");
- (* Print the type parameters *)
- if type_params <> [] then (
- F.pp_print_space fmt ();
- F.pp_print_string fmt "(";
- List.iter
- (fun s ->
- F.pp_print_string fmt s;
- F.pp_print_space fmt ())
- type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt (type_keyword () ^ ")"));
- (* Print the const generic parameters *)
- List.iter
- (fun (var : const_generic_var) ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt "(";
- let n = ctx_get_const_generic_var var.index ctx in
- F.pp_print_string fmt n;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_literal_type ctx fmt var.ty;
- F.pp_print_string fmt ")")
- def.const_generic_params);
- (* Print the "=" if we extract the body*)
- if extract_body then (
- F.pp_print_space fmt ();
- let eq =
- match !backend with
- | FStar -> "="
- | Coq -> ":="
- | Lean ->
- if type_kind = Some Struct && kind = SingleNonRec then "where"
- else ":="
- | HOL4 -> "="
- in
- F.pp_print_string fmt eq)
- else (
- (* Otherwise print ": Type", unless it is the HOL4 backend (in
- which case we declare the type with `new_type`) *)
- if use_forall then F.pp_print_string fmt ","
- else (
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":");
- F.pp_print_space fmt ();
- F.pp_print_string fmt (type_keyword ()));
- (* Close the box for "type TYPE_NAME (TYPE_PARAMS) =" *)
- F.pp_close_box fmt ();
- (if extract_body then
- match def.kind with
- | Struct fields ->
- extract_type_decl_struct_body ctx_body fmt type_decl_group kind def
- type_params cg_params fields
- | Enum variants ->
- extract_type_decl_enum_body ctx_body fmt type_decl_group def def_name
- type_params cg_params variants
- | Opaque -> raise (Failure "Unreachable"));
- (* Add the definition end delimiter *)
- if !backend = HOL4 && decl_is_not_last_from_group kind then (
- F.pp_print_space fmt ();
- F.pp_print_string fmt ";")
- else if !backend = Coq && decl_is_last_from_group kind then (
- (* This is actually an end of group delimiter. For aesthetic reasons
- we print it here instead of in {!end_type_decl_group}. *)
- F.pp_print_cut fmt ();
- F.pp_print_string fmt ".");
- (* Close the box for the definition *)
- F.pp_close_box fmt ();
- (* Add breaks to insert new lines between definitions *)
- if !backend <> HOL4 || decl_is_not_last_from_group kind then
- F.pp_print_break fmt 0 0
-
-(** Extract an opaque type declaration to HOL4.
-
- Remark (SH): having to treat this specific case separately is very annoying,
- but I could not find a better way.
- *)
-let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
- (def : type_decl) : unit =
- (* Retrieve the definition name *)
- let with_opaque_pre = false in
- let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
- (* Generic parameters are unsupported *)
- assert (def.const_generic_params = []);
- (* Count the number of parameters *)
- let num_params = List.length def.type_params in
- (* Generate the declaration *)
- F.pp_print_space fmt ();
- F.pp_print_string fmt
- ("val _ = new_type (\"" ^ def_name ^ "\", " ^ string_of_int num_params ^ ")");
- F.pp_print_space fmt ()
-
-(** Extract an empty record type declaration to HOL4.
-
- Empty records are not supported in HOL4, so we extract them as type
- abbreviations to the unit type.
-
- Remark (SH): having to treat this specific case separately is very annoying,
- but I could not find a better way.
- *)
-let extract_type_decl_hol4_empty_record (ctx : extraction_ctx)
- (fmt : F.formatter) (def : type_decl) : unit =
- (* Retrieve the definition name *)
- let with_opaque_pre = false in
- let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
- (* Sanity check *)
- assert (def.type_params = []);
- assert (def.const_generic_params = []);
- (* Generate the declaration *)
- F.pp_print_space fmt ();
- F.pp_print_string fmt ("Type " ^ def_name ^ " = “: unit”");
- F.pp_print_space fmt ()
-
-(** Extract a type declaration.
-
- Note that all the names used for extraction should already have been
- registered.
-
- This function should be inserted between calls to {!start_type_decl_group}
- and {!end_type_decl_group}.
- *)
-let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter)
- (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) :
- unit =
- let extract_body =
- match kind with
- | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> true
- | Assumed | Declared -> false
- in
- if extract_body then
- if !backend = HOL4 && is_empty_record_type_decl def then
- extract_type_decl_hol4_empty_record ctx fmt def
- else extract_type_decl_gen ctx fmt type_decl_group kind def extract_body
- else
- match !backend with
- | FStar | Coq | Lean ->
- extract_type_decl_gen ctx fmt type_decl_group kind def extract_body
- | HOL4 -> extract_type_decl_hol4_opaque ctx fmt def
-
-(** Auxiliary function.
-
- Generate [Arguments] instructions in Coq.
- *)
-let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
- (kind : decl_kind) (decl : type_decl) : unit =
- assert (!backend = Coq);
- (* Generating the [Arguments] instructions is useful only if there are type parameters *)
- if decl.type_params = [] && decl.const_generic_params = [] then ()
- else
- (* Add the type params - note that we need those bindings only for the
- * body translation (they are not top-level) *)
- let _ctx_body, type_params, cg_params =
- ctx_add_type_const_generic_params decl.type_params
- decl.const_generic_params ctx
- in
- (* Auxiliary function to extract an [Arguments Cons {T} _ _.] instruction *)
- let extract_arguments_info (cons_name : string) (fields : 'a list) : unit =
- (* Add a break before *)
- F.pp_print_break fmt 0 0;
- (* Open a box *)
- F.pp_open_hovbox fmt ctx.indent_incr;
- (* Small utility *)
- let print_vars () =
- List.iter
- (fun (var : string) ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt ("{" ^ var ^ "}"))
- (List.append type_params cg_params)
- in
- let print_fields () =
- List.iter
- (fun _ ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt "_")
- fields
- in
- F.pp_print_break fmt 0 0;
- F.pp_print_string fmt "Arguments";
- F.pp_print_space fmt ();
- F.pp_print_string fmt cons_name;
- print_vars ();
- print_fields ();
- F.pp_print_string fmt ".";
-
- (* Close the box *)
- F.pp_close_box fmt ()
- in
-
- (* Generate the [Arguments] instruction *)
- match decl.kind with
- | Opaque -> ()
- | Struct fields ->
- let adt_id = AdtId decl.def_id in
- (* Generate the instruction for the record constructor *)
- let with_opaque_pre = false in
- let cons_name = ctx_get_struct with_opaque_pre adt_id ctx in
- extract_arguments_info cons_name fields;
- (* Generate the instruction for the record projectors, if there are *)
- let is_rec = decl_is_from_rec_group kind in
- if not is_rec then
- FieldId.iteri
- (fun fid _ ->
- let cons_name = ctx_get_field adt_id fid ctx in
- extract_arguments_info cons_name [])
- fields;
- (* Add breaks to insert new lines between definitions *)
- F.pp_print_break fmt 0 0
- | Enum variants ->
- (* Generate the instructions *)
- VariantId.iteri
- (fun vid (v : variant) ->
- let cons_name = ctx_get_variant (AdtId decl.def_id) vid ctx in
- extract_arguments_info cons_name v.fields)
- variants;
- (* Add breaks to insert new lines between definitions *)
- F.pp_print_break fmt 0 0
-
-(** Auxiliary function.
-
- Generate field projectors in Coq.
-
- Sometimes we extract records as inductives in Coq: when this happens we
- have to define the field projectors afterwards.
- *)
-let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
- (fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit =
- assert (!backend = Coq);
- match decl.kind with
- | Opaque | Enum _ -> ()
- | Struct fields ->
- (* Records are extracted as inductives only if they are recursive *)
- let is_rec = decl_is_from_rec_group kind in
- if is_rec then
- (* Add the type params *)
- let ctx, type_params, cg_params =
- ctx_add_type_const_generic_params decl.type_params
- decl.const_generic_params ctx
- in
- let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in
- let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in
- let with_opaque_pre = false in
- let def_name = ctx_get_local_type with_opaque_pre decl.def_id ctx in
- let cons_name =
- ctx_get_struct with_opaque_pre (AdtId decl.def_id) ctx
- in
- let extract_field_proj (field_id : FieldId.id) (_ : field) : unit =
- F.pp_print_space fmt ();
- (* Outer box for the projector definition *)
- F.pp_open_hvbox fmt 0;
- (* Inner box for the projector definition *)
- F.pp_open_hvbox fmt ctx.indent_incr;
- (* Open a box for the [Definition PROJ ... :=] *)
- F.pp_open_hovbox fmt ctx.indent_incr;
- F.pp_print_string fmt "Definition";
- F.pp_print_space fmt ();
- let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in
- F.pp_print_string fmt field_name;
- F.pp_print_space fmt ();
- (* Print the type parameters *)
- if type_params <> [] then (
- F.pp_print_string fmt "{";
- List.iter
- (fun p ->
- F.pp_print_string fmt p;
- F.pp_print_space fmt ())
- type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "Type}";
- F.pp_print_space fmt ());
- (* Print the const generic parameters *)
- if cg_params <> [] then
- List.iter
- (fun (v : const_generic_var) ->
- F.pp_print_string fmt "{";
- let n = ctx_get_const_generic_var v.index ctx in
- F.pp_print_string fmt n;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_literal_type ctx fmt v.ty;
- F.pp_print_string fmt "}";
- F.pp_print_space fmt ())
- decl.const_generic_params;
- (* Print the record parameter *)
- F.pp_print_string fmt "(";
- F.pp_print_string fmt record_var;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt def_name;
- List.iter
- (fun p ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt p)
- type_params;
- F.pp_print_string fmt ")";
- (* *)
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":=";
- (* Close the box for the [Definition PROJ ... :=] *)
- F.pp_close_box fmt ();
- F.pp_print_space fmt ();
- (* Open a box for the whole match *)
- F.pp_open_hvbox fmt 0;
- (* Open a box for the [match ... with] *)
- F.pp_open_hovbox fmt ctx.indent_incr;
- F.pp_print_string fmt "match";
- F.pp_print_space fmt ();
- F.pp_print_string fmt record_var;
- F.pp_print_space fmt ();
- F.pp_print_string fmt "with";
- (* Close the box for the [match ... with] *)
- F.pp_close_box fmt ();
-
- (* Open a box for the branch *)
- F.pp_open_hovbox fmt ctx.indent_incr;
- (* Print the match branch *)
- F.pp_print_space fmt ();
- F.pp_print_string fmt "|";
- F.pp_print_space fmt ();
- F.pp_print_string fmt cons_name;
- FieldId.iteri
- (fun id _ ->
- F.pp_print_space fmt ();
- if field_id = id then F.pp_print_string fmt field_var
- else F.pp_print_string fmt "_")
- fields;
- F.pp_print_space fmt ();
- F.pp_print_string fmt "=>";
- F.pp_print_space fmt ();
- F.pp_print_string fmt field_var;
- (* Close the box for the branch *)
- F.pp_close_box fmt ();
- (* Print the [end] *)
- F.pp_print_space fmt ();
- F.pp_print_string fmt "end";
- (* Close the box for the whole match *)
- F.pp_close_box fmt ();
- (* Close the inner box projector *)
- F.pp_close_box fmt ();
- (* If Coq: end the definition with a "." *)
- if !backend = Coq then (
- F.pp_print_cut fmt ();
- F.pp_print_string fmt ".");
- (* Close the outer box projector *)
- F.pp_close_box fmt ();
- (* Add breaks to insert new lines between definitions *)
- F.pp_print_break fmt 0 0
- in
-
- let extract_proj_notation (field_id : FieldId.id) (_ : field) : unit =
- F.pp_print_space fmt ();
- (* Outer box for the projector definition *)
- F.pp_open_hvbox fmt 0;
- (* Inner box for the projector definition *)
- F.pp_open_hovbox fmt ctx.indent_incr;
- let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in
- F.pp_print_string fmt "Notation";
- F.pp_print_space fmt ();
- let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in
- F.pp_print_string fmt ("\"" ^ record_var ^ " .(" ^ field_name ^ ")\"");
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":=";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "(";
- F.pp_print_string fmt field_name;
- F.pp_print_space fmt ();
- F.pp_print_string fmt record_var;
- F.pp_print_string fmt ")";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "(at level 9)";
- (* Close the inner box projector *)
- F.pp_close_box fmt ();
- (* If Coq: end the definition with a "." *)
- if !backend = Coq then (
- F.pp_print_cut fmt ();
- F.pp_print_string fmt ".");
- (* Close the outer box projector *)
- F.pp_close_box fmt ();
- (* Add breaks to insert new lines between definitions *)
- F.pp_print_break fmt 0 0
- in
-
- let extract_field_proj_and_notation (field_id : FieldId.id)
- (field : field) : unit =
- extract_field_proj field_id field;
- extract_proj_notation field_id field
- in
-
- FieldId.iteri extract_field_proj_and_notation fields
-
-(** Extract extra information for a type (e.g., [Arguments] instructions in Coq).
-
- Note that all the names used for extraction should already have been
- registered.
- *)
-let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter)
- (kind : decl_kind) (decl : type_decl) : unit =
- match !backend with
- | FStar | Lean | HOL4 -> ()
- | Coq ->
- extract_type_decl_coq_arguments ctx fmt kind decl;
- extract_type_decl_record_field_projectors ctx fmt kind decl
-
-(** Extract the state type declaration. *)
-let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx)
- (kind : decl_kind) : unit =
- (* Add a break before *)
- F.pp_print_break fmt 0 0;
- (* Print a comment *)
- extract_comment fmt [ "The state type used in the state-error monad" ];
- F.pp_print_break fmt 0 0;
- (* Open a box for the definition, so that whenever possible it gets printed on
- * one line *)
- F.pp_open_hvbox fmt 0;
- (* Retrieve the name *)
- let state_name = ctx_get_assumed_type State ctx in
- (* The syntax for Lean and Coq is almost identical. *)
- let print_axiom () =
- let axiom =
- match !backend with
- | Coq -> "Axiom"
- | Lean -> "axiom"
- | FStar | HOL4 -> raise (Failure "Unexpected")
- in
- F.pp_print_string fmt axiom;
- F.pp_print_space fmt ();
- F.pp_print_string fmt state_name;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "Type";
- if !backend = Coq then F.pp_print_string fmt "."
- in
- (* The kind should be [Assumed] or [Declared] *)
- (match kind with
- | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast ->
- raise (Failure "Unexpected")
- | Assumed -> (
- match !backend with
- | FStar ->
- F.pp_print_string fmt "assume";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "type";
- F.pp_print_space fmt ();
- F.pp_print_string fmt state_name;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "Type0"
- | HOL4 ->
- F.pp_print_string fmt ("val _ = new_type (\"" ^ state_name ^ "\", 0)")
- | Coq | Lean -> print_axiom ())
- | Declared -> (
- match !backend with
- | FStar ->
- F.pp_print_string fmt "val";
- F.pp_print_space fmt ();
- F.pp_print_string fmt state_name;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "Type0"
- | HOL4 ->
- F.pp_print_string fmt ("val _ = new_type (\"" ^ state_name ^ "\", 0)")
- | Coq | Lean -> print_axiom ()));
- (* Close the box for the definition *)
- F.pp_close_box fmt ();
- (* Add breaks to insert new lines between definitions *)
- F.pp_print_break fmt 0 0
+include ExtractTypes
(** Compute the names for all the pure functions generated from a rust function
(forward function and backward functions).
*)
-let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool)
+let extract_fun_decl_register_names (ctx : extraction_ctx)
(has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) :
extraction_ctx =
- let (fwd, loop_fwds), back_ls = def in
- (* Register the decrease clauses, if necessary *)
- let register_decreases ctx def =
- if has_decreases_clause def then
- (* Add the termination measure *)
- let ctx = ctx_add_termination_measure def ctx in
- (* Add the decreases proof for Lean only *)
- match !Config.backend with
- | Coq | FStar -> ctx
- | HOL4 -> raise (Failure "Unexpected")
- | Lean -> ctx_add_decreases_proof def ctx
- else ctx
- in
- let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in
- let register_fun ctx f = ctx_add_fun_decl (keep_fwd, def) f ctx in
- let register_funs ctx fl = List.fold_left register_fun ctx fl in
- (* Register the forward functions' names *)
- let ctx = register_funs ctx (fwd :: loop_fwds) in
- (* Register the backward functions' names *)
- let ctx =
- List.fold_left
- (fun ctx (back, loop_backs) ->
- let ctx = register_fun ctx back in
- register_funs ctx loop_backs)
- ctx back_ls
- in
-
- (* Return *)
- ctx
+ (* Ignore the trait methods **declarations** (rem.: we do not ignore the trait
+ method implementations): we do not need to refer to them directly. We will
+ only use their type for the fields of the records we generate for the trait
+ declarations *)
+ match def.fwd.f.kind with
+ | TraitMethodDecl _ -> ctx
+ | _ -> (
+ (* Check if the function is builtin *)
+ let builtin =
+ let open ExtractBuiltin in
+ let funs_map = builtin_funs_map () in
+ let sname = name_to_simple_name def.fwd.f.basename in
+ SimpleNameMap.find_opt sname funs_map
+ in
+ (* Use the builtin names if necessary *)
+ match builtin with
+ | Some (filter_info, info) ->
+ (* Register the filtering information, if there is *)
+ let ctx =
+ match filter_info with
+ | Some keep ->
+ {
+ ctx with
+ funs_filter_type_args_map =
+ FunDeclId.Map.add def.fwd.f.def_id keep
+ ctx.funs_filter_type_args_map;
+ }
+ | _ -> ctx
+ in
+ let backs = List.map (fun f -> f.f) def.backs in
+ let funs = if def.keep_fwd then def.fwd.f :: backs else backs in
+ List.fold_left
+ (fun ctx (f : fun_decl) ->
+ let open ExtractBuiltin in
+ let fun_id =
+ (Pure.FunId (Regular f.def_id), f.loop_id, f.back_id)
+ in
+ let fun_info =
+ List.find_opt
+ (fun (x : builtin_fun_info) -> x.rg = f.back_id)
+ info
+ in
+ match fun_info with
+ | Some fun_info ->
+ ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx
+ | None ->
+ raise
+ (Failure
+ ("Not found: "
+ ^ Names.name_to_string f.basename
+ ^ ", "
+ ^ Print.option_to_string Pure.show_loop_id f.loop_id
+ ^ Print.option_to_string Pure.show_region_group_id
+ f.back_id)))
+ ctx funs
+ | None ->
+ let fwd = def.fwd in
+ let backs = def.backs in
+ (* Register the decrease clauses, if necessary *)
+ let register_decreases ctx def =
+ if has_decreases_clause def then
+ (* Add the termination measure *)
+ let ctx = ctx_add_termination_measure def ctx in
+ (* Add the decreases proof for Lean only *)
+ match !Config.backend with
+ | Coq | FStar -> ctx
+ | HOL4 -> raise (Failure "Unexpected")
+ | Lean -> ctx_add_decreases_proof def ctx
+ else ctx
+ in
+ let ctx =
+ List.fold_left register_decreases ctx (fwd.f :: fwd.loops)
+ in
+ let register_fun ctx f = ctx_add_fun_decl def f ctx in
+ let register_funs ctx fl = List.fold_left register_fun ctx fl in
+ (* Register the names of the forward functions *)
+ let ctx =
+ if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx
+ in
+ (* Register the names of the backward functions *)
+ List.fold_left
+ (fun ctx { f = back; loops = loop_backs } ->
+ let ctx = register_fun ctx back in
+ register_funs ctx loop_backs)
+ ctx backs)
(** Simply add the global name to the context. *)
let extract_global_decl_register_names (ctx : extraction_ctx)
@@ -2122,11 +124,11 @@ let extract_adt_g_value
(inside : bool) (variant_id : VariantId.id option) (field_values : 'v list)
(ty : ty) : extraction_ctx =
match ty with
- | Adt (Tuple, type_args, cg_args) ->
+ | Adt (Tuple, generics) ->
(* Tuple *)
(* For now, we only support fully applied tuple constructors *)
- assert (List.length type_args = List.length field_values);
- assert (cg_args = []);
+ assert (List.length generics.types = List.length field_values);
+ assert (generics.const_generics = [] && generics.trait_refs = []);
(* This is very annoying: in Coq, we can't write [()] for the value of
type [unit], we have to write [tt]. *)
if !backend = Coq && field_values = [] then (
@@ -2144,7 +146,7 @@ let extract_adt_g_value
in
F.pp_print_string fmt ")";
ctx)
- | Adt (adt_id, _, _) ->
+ | Adt (adt_id, _) ->
(* "Regular" ADT *)
(* If we are generating a pattern for a let-binding and we target Lean,
@@ -2172,18 +174,14 @@ let extract_adt_g_value
* [{ field0=...; ...; fieldn=...; }] in case of structures.
*)
let cons =
- (* The ADT shouldn't be opaque *)
- let with_opaque_pre = false in
match variant_id with
| Some vid -> (
(* In the case of Lean, we might have to add the type name as a prefix *)
match (!backend, adt_id) with
| Lean, Assumed _ ->
- ctx_get_type with_opaque_pre adt_id ctx
- ^ "."
- ^ ctx_get_variant adt_id vid ctx
+ ctx_get_type adt_id ctx ^ "." ^ ctx_get_variant adt_id vid ctx
| _ -> ctx_get_variant adt_id vid ctx)
- | None -> ctx_get_struct with_opaque_pre adt_id ctx
+ | None -> ctx_get_struct adt_id ctx
in
let use_parentheses = inside && field_values <> [] in
if use_parentheses then F.pp_print_string fmt "(";
@@ -2202,8 +200,33 @@ let extract_adt_g_value
(* Extract globals in the same way as variables *)
let extract_global (ctx : extraction_ctx) (fmt : F.formatter)
(id : A.GlobalDeclId.id) : unit =
- let with_opaque_pre = ctx.use_opaque_pre in
- F.pp_print_string fmt (ctx_get_global with_opaque_pre id ctx)
+ F.pp_print_string fmt (ctx_get_global id ctx)
+
+(* Filter the generics of a function if it is builtin *)
+let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list)
+ (ctx : extraction_ctx) : ('a list, 'a list * string) Result.result =
+ match FunDeclId.Map.find_opt id ctx.funs_filter_type_args_map with
+ | None -> Result.Ok types
+ | Some filter ->
+ if List.length filter <> List.length types then (
+ let decl = FunDeclId.Map.find id ctx.trans_funs in
+ let err =
+ "Ill-formed builtin information for function "
+ ^ Names.name_to_string decl.fwd.f.basename
+ ^ ": "
+ ^ string_of_int (List.length filter)
+ ^ " filtering arguments provided for "
+ ^ string_of_int (List.length types)
+ ^ " type arguments"
+ in
+ log#serror err;
+ Result.Error (types, err))
+ else
+ let types = List.combine filter types in
+ let types =
+ List.filter_map (fun (b, ty) -> if b then Some ty else None) types
+ in
+ Result.Ok types
(** [inside]: see {!extract_ty}.
@@ -2218,7 +241,7 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter)
ctx
| PatVar (v, _) ->
let vname =
- ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty
+ ctx.fmt.var_basename ctx.names_maps.names_map.names_set v.basename v.ty
in
let ctx, vname = ctx_add_var vname v.id ctx in
F.pp_print_string fmt vname;
@@ -2249,6 +272,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
| Var var_id ->
let var_name = ctx_get_var var_id ctx in
F.pp_print_string fmt var_name
+ | CVar var_id ->
+ let var_name = ctx_get_const_generic_var var_id ctx in
+ F.pp_print_string fmt var_name
| Const cv -> ctx.fmt.extract_literal fmt inside cv
| App _ ->
let app, args = destruct_apps e in
@@ -2279,14 +305,26 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(* Top-level qualifier *)
match qualif.id with
| FunOrOp fun_id ->
- extract_function_call ctx fmt inside fun_id qualif.type_args
- qualif.const_generic_args args
+ extract_function_call ctx fmt inside fun_id qualif.generics args
| Global global_id -> extract_global ctx fmt global_id
| AdtCons adt_cons_id ->
- extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args
- qualif.const_generic_args args
+ extract_adt_cons ctx fmt inside adt_cons_id qualif.generics args
| Proj proj ->
- extract_field_projector ctx fmt inside app proj qualif.type_args args)
+ extract_field_projector ctx fmt inside app proj qualif.generics args
+ | TraitConst (trait_ref, generics, const_name) ->
+ let use_brackets = generics <> empty_generic_args in
+ if use_brackets then F.pp_print_string fmt "(";
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref;
+ extract_generic_args ctx fmt TypeDeclId.Set.empty generics;
+ let name =
+ ctx_get_trait_const trait_ref.trait_decl_ref.trait_decl_id
+ const_name ctx
+ in
+ let add_brackets (s : string) =
+ if !backend = Coq then "(" ^ s ^ ")" else s
+ in
+ if use_brackets then F.pp_print_string fmt ")";
+ F.pp_print_string fmt ("." ^ add_brackets name))
| _ ->
(* "Regular" expression *)
(* Open parentheses *)
@@ -2309,8 +347,8 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(** Subcase of the app case: function call *)
and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (fid : fun_or_op_id) (type_args : ty list)
- (cg_args : const_generic list) (args : texpression list) : unit =
+ (inside : bool) (fid : fun_or_op_id) (generics : generic_args)
+ (args : texpression list) : unit =
match (fid, args) with
| Unop unop, [ arg ] ->
(* A unop can have *at most* one argument (the result can't be a function!).
@@ -2327,24 +365,124 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
if inside then F.pp_print_string fmt "(";
(* Open a box for the function call *)
F.pp_open_hovbox fmt ctx.indent_incr;
- (* Print the function name *)
- let with_opaque_pre = ctx.use_opaque_pre in
- let fun_name = ctx_get_function with_opaque_pre fun_id ctx in
- F.pp_print_string fmt fun_name;
- (* Sanity check: HOL4 doesn't support const generics *)
- assert (cg_args = [] || !backend <> HOL4);
- (* Print the type parameters, if the backend is not HOL4 *)
- if !backend <> HOL4 then (
- List.iter
- (fun ty ->
- F.pp_print_space fmt ();
- extract_ty ctx fmt TypeDeclId.Set.empty true ty)
- type_args;
- List.iter
- (fun cg ->
+ (* Print the function name.
+
+ For the function name: the id is not the same depending on whether
+ we call a trait method and a "regular" function (remark: trait
+ method *implementations* are considered as regular functions here;
+ only calls to method of traits which are parameterized in a where
+ clause have a special treatment.
+
+ Remark: the reason why trait method declarations have a special
+ treatment is that, as traits are extracted to records, we may
+ allow collisions between trait item names and some other names,
+ while we do not allow collisions between function names.
+
+ # Impl trait refs:
+ ==================
+ When the trait ref refers to an impl, in
+ [InterpreterStatement.eval_transparent_function_call_symbolic] we
+ replace the call to the trait impl method to a call to the function
+ which implements the trait method (that is, we "forget" that we
+ called a trait method, and treat it as a regular function call).
+
+ # Provided trait methods:
+ =========================
+ Calls to provided trait methods also have a special treatment.
+ For now, we do not allow overriding provided trait methods (methods
+ for which a default implementation is provided in the trait declaration).
+ Whenever we translate a provided trait method, we translate it once as
+ a function which takes a trait ref as input. We have to handle this
+ case below.
+
+ With an example, if in Rust we write:
+ {[
+ fn Foo {
+ fn f(&self) -> u32; // Required
+ fn ret_true(&self) -> bool { true } // Provided
+ }
+ ]}
+
+ We generate:
+ {[
+ structure Foo (Self : Type) = {
+ f : Self -> result u32
+ }
+
+ let ret_true (Self : Type) (self_clause : Foo Self) (self : Self) : result bool =
+ true
+ ]}
+ *)
+ (match fun_id with
+ | FromLlbc
+ (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id, rg_id) ->
+ (* We have to check whether the trait method is required or provided *)
+ let trait_decl_id = trait_ref.trait_decl_ref.trait_decl_id in
+ let trait_decl =
+ TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls
+ in
+ let method_id =
+ PureUtils.trait_decl_get_method trait_decl method_name
+ in
+
+ if not method_id.is_provided then (
+ (* Required method *)
+ assert (lp_id = None);
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref;
+ let fun_name =
+ ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id
+ method_name rg_id ctx
+ in
+ let add_brackets (s : string) =
+ if !backend = Coq then "(" ^ s ^ ")" else s
+ in
+ F.pp_print_string fmt ("." ^ add_brackets fun_name))
+ else
+ (* Provided method: we see it as a regular function call, and use
+ the function name *)
+ let fun_id =
+ FromLlbc (FunId (Regular method_id.id), lp_id, rg_id)
+ in
+ let fun_name = ctx_get_function fun_id ctx in
+ F.pp_print_string fmt fun_name;
+
+ (* Note that we do not need to print the generics for the trait
+ declaration: they are always implicit as they can be deduced
+ from the trait self clause.
+
+ Print the trait ref (to instantate the self clause) *)
F.pp_print_space fmt ();
- extract_const_generic ctx fmt true cg)
- cg_args);
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref
+ | _ ->
+ let fun_name = ctx_get_function fun_id ctx in
+ F.pp_print_string fmt fun_name);
+
+ (* Sanity check: HOL4 doesn't support const generics *)
+ assert (generics.const_generics = [] || !backend <> HOL4);
+ (* Print the generics.
+
+ We might need to filter some of the type arguments, if the type
+ is builtin (for instance, we filter the global allocator type
+ argument for `Vec::new`).
+ *)
+ let types =
+ match fun_id with
+ | FromLlbc (FunId (Regular id), _, _) ->
+ fun_builtin_filter_types id generics.types ctx
+ | _ -> Result.Ok generics.types
+ in
+ (match types with
+ | Ok types ->
+ extract_generic_args ctx fmt TypeDeclId.Set.empty
+ { generics with types }
+ | Error (types, err) ->
+ extract_generic_args ctx fmt TypeDeclId.Set.empty
+ { generics with types };
+ if !Config.fail_hard then raise (Failure err)
+ else
+ F.pp_print_string fmt
+ "(\"ERROR: ill-formed builtin: invalid number of filtering \
+ arguments\")");
(* Print the arguments *)
List.iter
(fun ve ->
@@ -2366,9 +504,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
(** Subcase of the app case: ADT constructor *)
and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
- (adt_cons : adt_cons_id) (type_args : ty list)
- (cg_args : const_generic list) (args : texpression list) : unit =
- let e_ty = Adt (adt_cons.adt_id, type_args, cg_args) in
+ (adt_cons : adt_cons_id) (generics : generic_args) (args : texpression list)
+ : unit =
+ let e_ty = Adt (adt_cons.adt_id, generics) in
let is_single_pat = false in
let _ =
extract_adt_g_value
@@ -2382,7 +520,7 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(** Subcase of the app case: ADT field projector. *)
and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter)
(inside : bool) (original_app : texpression) (proj : projection)
- (_proj_type_params : ty list) (args : texpression list) : unit =
+ (_generics : generic_args) (args : texpression list) : unit =
(* We isolate the first argument (if there is), in order to pretty print the
* projection ([x.field] instead of [MkAdt?.field x] *)
match args with
@@ -2734,9 +872,7 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
let extract_as_unit =
match (!backend, supd.struct_id) with
| HOL4, AdtId adt_id ->
- let d =
- TypeDeclId.Map.find adt_id ctx.trans_ctx.type_context.type_decls
- in
+ let d = TypeDeclId.Map.find adt_id ctx.trans_ctx.type_ctx.type_decls in
d.kind = Struct []
| _ -> false
in
@@ -2835,17 +971,17 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_open_hvbox fmt ctx.indent_incr;
let need_paren = inside in
if need_paren then F.pp_print_string fmt "(";
- (* Open the box for `Array.mk T N [` *)
+ (* Open the box for `Array.replicate T N [` *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* Print the array constructor *)
- let cs = ctx_get_struct false (Assumed Array) ctx in
+ let cs = ctx_get_struct (Assumed Array) ctx in
F.pp_print_string fmt cs;
(* Print the parameters *)
- let _, tys, cgs = ty_as_adt e_ty in
- let ty = Collections.List.to_cons_nil tys in
+ let _, generics = ty_as_adt e_ty in
+ let ty = Collections.List.to_cons_nil generics.types in
F.pp_print_space fmt ();
extract_ty ctx fmt TypeDeclId.Set.empty true ty;
- let cg = Collections.List.to_cons_nil cgs in
+ let cg = Collections.List.to_cons_nil generics.const_generics in
F.pp_print_space fmt ();
extract_const_generic ctx fmt true cg;
F.pp_print_space fmt ();
@@ -2872,17 +1008,15 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_close_box fmt ()
| _ -> raise (Failure "Unreachable")
-(** Insert a space, if necessary *)
-let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
- if !space then space := false else F.pp_print_space fmt ()
-
(** A small utility to print the parameters of a function signature.
We return two contexts:
- - the context augmented with bindings for the type parameters
- - the context augmented with bindings for the type parameters *and*
+ - the context augmented with bindings for the generics
+ - the context augmented with bindings for the generics *and*
bindings for the input values
+ We also return names for the type parameters, const generics, etc.
+
TODO: do we really need the first one? We should probably always use
the second one.
It comes from the fact that when we print the input values for the
@@ -2890,57 +1024,40 @@ let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
patterns, not the variables). We should figure a cleaner way.
*)
let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
- (fmt : F.formatter) (def : fun_decl) : extraction_ctx * extraction_ctx =
+ (fmt : F.formatter) (def : fun_decl) :
+ extraction_ctx * extraction_ctx * string list =
+ (* First, add the associated types and constants if the function is a method
+ in a trait declaration.
+
+ About the order: we want to make sure the names are reserved for
+ those (variable names might collide with them but it is ok, we will add
+ suffixes to the variables).
+
+ TODO: micro-pass to update what happens when calling trait provided
+ functions.
+ *)
+ let ctx, trait_decl =
+ match def.kind with
+ | TraitMethodProvided (decl_id, _) ->
+ let trait_decl = T.TraitDeclId.Map.find decl_id ctx.trans_trait_decls in
+ let ctx, _ = ctx_add_trait_self_clause ctx in
+ let ctx = { ctx with is_provided_method = true } in
+ (ctx, Some trait_decl)
+ | _ -> (ctx, None)
+ in
(* Add the type parameters - note that we need those bindings only for the
* body translation (they are not top-level) *)
- let ctx, type_params, cg_params =
- ctx_add_type_const_generic_params def.signature.type_params
- def.signature.const_generic_params ctx
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params def.signature.generics ctx
in
- (* Print the parameters - rem.: we should have filtered the functions
- * with no input parameters *)
- (* The type parameters.
-
- Note that in HOL4 we don't print the type parameters.
- *)
- if (type_params <> [] || cg_params <> []) && !backend <> HOL4 then (
- (* Open a box for the type and const generic parameters *)
- F.pp_open_hovbox fmt 0;
- (* The type parameters *)
- if type_params <> [] then (
- insert_req_space fmt space;
- F.pp_print_string fmt "(";
- List.iter
- (fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx in
- F.pp_print_string fmt pname;
- F.pp_print_space fmt ())
- def.signature.type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- let type_keyword =
- match !backend with
- | FStar -> "Type0"
- | Coq | Lean -> "Type"
- | HOL4 -> raise (Failure "Unreachable")
- in
- F.pp_print_string fmt (type_keyword ^ ")"));
- (* The const generic parameters *)
- if cg_params <> [] then
- List.iter
- (fun (p : const_generic_var) ->
- let pname = ctx_get_const_generic_var p.index ctx in
- insert_req_space fmt space;
- F.pp_print_string fmt "(";
- F.pp_print_string fmt pname;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_literal_type ctx fmt p.ty;
- F.pp_print_string fmt ")")
- def.signature.const_generic_params;
- (* Close the box for the type parameters *)
- F.pp_close_box fmt ());
+ (* Print the generics *)
+ (* Open a box for the generics *)
+ F.pp_open_hovbox fmt 0;
+ (let space = Some space in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~space ~trait_decl
+ def.signature.generics type_params cg_params trait_clauses);
+ (* Close the box for the generics *)
+ F.pp_close_box fmt ();
(* The input parameters - note that doing this adds bindings to the context *)
let ctx_body =
match def.body with
@@ -2963,7 +1080,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
ctx)
ctx body.inputs_lvs
in
- (ctx, ctx_body)
+ (ctx, ctx_body, List.concat [ type_params; cg_params; trait_clauses ])
(** A small utility to print the types of the input parameters in the form:
[u32 -> list u32 -> ...]
@@ -2982,6 +1099,11 @@ let extract_fun_input_parameters_types (ctx : extraction_ctx)
in
List.iter extract_param def.signature.inputs
+let extract_fun_inputs_output_parameters_types (ctx : extraction_ctx)
+ (fmt : F.formatter) (def : fun_decl) : unit =
+ extract_fun_input_parameters_types ctx fmt def;
+ extract_ty ctx fmt TypeDeclId.Set.empty false def.signature.output
+
let assert_backend_supports_decreases_clauses () =
match !backend with
| FStar | Lean -> ()
@@ -3032,7 +1154,7 @@ let extract_template_fstar_decreases_clause (ctx : extraction_ctx)
F.pp_print_space fmt ();
(* Extract the parameters *)
let space = ref true in
- let _, _ = extract_fun_parameters space ctx fmt def in
+ let _, _, _ = extract_fun_parameters space ctx fmt def in
insert_req_space fmt space;
F.pp_print_string fmt ":";
(* Print the signature *)
@@ -3094,7 +1216,7 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx)
F.pp_print_space fmt ();
(* Extract the parameters *)
let space = ref true in
- let _, ctx_body = extract_fun_parameters space ctx fmt def in
+ let _, ctx_body, _ = extract_fun_parameters space ctx fmt def in
(* Print the ":=" *)
F.pp_print_space fmt ();
F.pp_print_string fmt ":=";
@@ -3164,7 +1286,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
let { keep_fwd; num_backs } =
PureUtils.RegularFunIdMap.find
- (A.Regular def.def_id, def.loop_id, def.back_id)
+ (Pure.FunId (Regular def.def_id), def.loop_id, def.back_id)
ctx.fun_name_info
in
let comment_pre = "[" ^ Print.fun_name_to_string def.basename ^ "]: " in
@@ -3205,10 +1327,8 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit =
assert (not def.is_global_decl_body);
(* Retrieve the function name *)
- let with_opaque_pre = false in
let def_name =
- ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id
- ctx
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx
in
(* Add a break before *)
if !backend <> HOL4 || not (decl_is_first_from_group kind) then
@@ -3234,23 +1354,15 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
*)
let is_opaque_coq = !backend = Coq && is_opaque in
let use_forall =
- is_opaque_coq
- && (def.signature.type_params <> []
- || def.signature.const_generic_params <> [])
+ is_opaque_coq && def.signature.generics <> empty_generic_params
in
- (* Print the qualifier ("assume", etc.).
-
- if `wrap_opaque_in_sig`: we generate a record of assumed funcions.
- TODO: this is obsolete.
- *)
- (if not (!Config.wrap_opaque_in_sig && (kind = Assumed || kind = Declared))
- then
- let qualif = ctx.fmt.fun_decl_kind_to_qualif kind in
- match qualif with
- | Some qualif ->
- F.pp_print_string fmt qualif;
- F.pp_print_space fmt ()
- | None -> ());
+ (* Print the qualifier ("assume", etc.). *)
+ let qualif = ctx.fmt.fun_decl_kind_to_qualif kind in
+ (match qualif with
+ | Some qualif ->
+ F.pp_print_string fmt qualif;
+ F.pp_print_space fmt ()
+ | None -> ());
F.pp_print_string fmt def_name;
F.pp_print_space fmt ();
if use_forall then (
@@ -3262,7 +1374,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Open a box for "(PARAMS) :" *)
F.pp_open_hovbox fmt 0;
let space = ref true in
- let ctx, ctx_body = extract_fun_parameters space ctx fmt def in
+ let ctx, ctx_body, all_params = extract_fun_parameters space ctx fmt def in
(* Print the return type - note that we have to be careful when
* printing the input values for the decrease clause, because
* it introduces bindings in the context... We thus "forget"
@@ -3310,20 +1422,13 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* The name of the decrease clause *)
let decr_name = ctx_get_termination_measure def.def_id def.loop_id ctx in
F.pp_print_string fmt decr_name;
- (* Print the type/const generic parameters - TODO: we do this many
+ (* Print the generic parameters - TODO: we do this many
times, we should have a helper to factor it out *)
List.iter
- (fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx in
+ (fun (name : string) ->
F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.type_params;
- List.iter
- (fun (p : const_generic_var) ->
- let pname = ctx_get_const_generic_var p.index ctx in
- F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.const_generic_params;
+ F.pp_print_string fmt name)
+ all_params;
(* Print the input values: we have to be careful here to print
* only the input values which are in common with the *forward*
* function (the additional input values "given back" to the
@@ -3410,19 +1515,12 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Open the box for [DECREASES] *)
F.pp_open_hovbox fmt ctx.indent_incr;
F.pp_print_string fmt terminates_name;
- (* Print the type/const generic params - TODO: factor out *)
+ (* Print the generic params - TODO: factor out *)
List.iter
- (fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx in
+ (fun (name : string) ->
F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.type_params;
- List.iter
- (fun (p : const_generic_var) ->
- let pname = ctx_get_const_generic_var p.index ctx in
- F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.const_generic_params;
+ F.pp_print_string fmt name)
+ all_params;
(* Print the variables *)
List.iter
(fun v ->
@@ -3475,18 +1573,13 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
(* Retrieve the definition name *)
- let with_opaque_pre = false in
let def_name =
- ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id
- ctx
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx
in
- assert (def.signature.const_generic_params = []);
+ assert (def.signature.generics.const_generics = []);
(* Add the type/const gen parameters - note that we need those bindings
only for the generation of the type (they are not top-level) *)
- let ctx, _, _ =
- ctx_add_type_const_generic_params def.signature.type_params
- def.signature.const_generic_params ctx
- in
+ let ctx, _, _, _ = ctx_add_generic_params def.signature.generics ctx in
(* Add breaks to insert new lines between definitions *)
F.pp_print_break fmt 0 0;
(* Open a box for the whole definition *)
@@ -3635,8 +1728,13 @@ let extract_global_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
(* Print the type *)
F.pp_open_hovbox fmt 0;
extract_ty ctx fmt TypeDeclId.Set.empty false ty;
+ (* Close the definition *)
+ F.pp_print_string fmt ")";
+ F.pp_close_box fmt ();
+ (* Close the definition box *)
F.pp_close_box fmt ();
- (* Close the definition boxe *) F.pp_close_box fmt ()
+ (* Add a line *)
+ F.pp_print_space fmt ()
(** Extract a global declaration.
@@ -3662,21 +1760,19 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(global : A.global_decl) (body : fun_decl) (interface : bool) : unit =
assert body.is_global_decl_body;
assert (Option.is_none body.back_id);
- assert (List.length body.signature.inputs = 0);
+ assert (body.signature.inputs = []);
assert (List.length body.signature.doutputs = 1);
- assert (List.length body.signature.type_params = 0);
- assert (List.length body.signature.const_generic_params = 0);
+ assert (body.signature.generics = empty_generic_params);
(* Add a break then the name of the corresponding LLBC declaration *)
F.pp_print_break fmt 0 0;
extract_comment fmt [ "[" ^ Print.global_name_to_string global.name ^ "]" ];
F.pp_print_space fmt ();
- let with_opaque_pre = false in
- let decl_name = ctx_get_global with_opaque_pre global.def_id ctx in
+ let decl_name = ctx_get_global global.def_id ctx in
let body_name =
- ctx_get_function with_opaque_pre
- (FromLlbc (Regular global.body_id, None, None))
+ ctx_get_function
+ (FromLlbc (Pure.FunId (Regular global.body_id), None, None))
ctx
in
@@ -3713,6 +1809,807 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(* Add a break to insert lines between declarations *)
F.pp_print_break fmt 0 0
+(** Similar to {!extract_trait_decl_register_names} *)
+let extract_trait_decl_register_parent_clause_names (ctx : extraction_ctx)
+ (trait_decl : trait_decl)
+ (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) :
+ extraction_ctx =
+ (* Compute the clause names *)
+ let clause_names =
+ match builtin_info with
+ | None ->
+ List.map
+ (fun (c : trait_clause) ->
+ let name = ctx.fmt.trait_parent_clause_name trait_decl c in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx.fmt.trait_decl_name trait_decl ^ name
+ in
+ (c.clause_id, name))
+ trait_decl.parent_clauses
+ | Some info ->
+ List.map
+ (fun (c, name) -> (c.clause_id, name))
+ (List.combine trait_decl.parent_clauses info.parent_clauses)
+ in
+ (* Register the names *)
+ List.fold_left
+ (fun ctx (cid, cname) ->
+ ctx_add (TraitParentClauseId (trait_decl.def_id, cid)) cname ctx)
+ ctx clause_names
+
+(** Similar to {!extract_trait_decl_register_names} *)
+let extract_trait_decl_register_constant_names (ctx : extraction_ctx)
+ (trait_decl : trait_decl)
+ (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) :
+ extraction_ctx =
+ let consts = trait_decl.consts in
+ (* Compute the names *)
+ let constant_names =
+ match builtin_info with
+ | None ->
+ List.map
+ (fun (item_name, _) ->
+ let name = ctx.fmt.trait_const_name trait_decl item_name in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx.fmt.trait_decl_name trait_decl ^ name
+ in
+ (item_name, name))
+ consts
+ | Some info ->
+ let const_map = StringMap.of_list info.consts in
+ List.map
+ (fun (item_name, _) ->
+ (item_name, StringMap.find item_name const_map))
+ consts
+ in
+ (* Register the names *)
+ List.fold_left
+ (fun ctx (item_name, name) ->
+ ctx_add (TraitItemId (trait_decl.def_id, item_name)) name ctx)
+ ctx constant_names
+
+(** Similar to {!extract_trait_decl_register_names} *)
+let extract_trait_decl_type_names (ctx : extraction_ctx)
+ (trait_decl : trait_decl)
+ (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) :
+ extraction_ctx =
+ let types = trait_decl.types in
+ (* Compute the names *)
+ let type_names =
+ match builtin_info with
+ | None ->
+ let compute_type_name (item_name : string) : string =
+ let type_name = ctx.fmt.trait_type_name trait_decl item_name in
+ if !Config.record_fields_short_names then type_name
+ else ctx.fmt.trait_decl_name trait_decl ^ type_name
+ in
+ let compute_clause_name (item_name : string) (clause : trait_clause) :
+ TraitClauseId.id * string =
+ let name =
+ ctx.fmt.trait_type_clause_name trait_decl item_name clause
+ in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx.fmt.trait_decl_name trait_decl ^ name
+ in
+ (clause.clause_id, name)
+ in
+ List.map
+ (fun (item_name, (item_clauses, _)) ->
+ (* Type name *)
+ let type_name = compute_type_name item_name in
+ (* Clause names *)
+ let clauses =
+ List.map (compute_clause_name item_name) item_clauses
+ in
+ (item_name, (type_name, clauses)))
+ types
+ | Some info ->
+ let type_map = StringMap.of_list info.types in
+ List.map
+ (fun (item_name, (item_clauses, _)) ->
+ let type_name, clauses_info = StringMap.find item_name type_map in
+ let clauses =
+ List.map
+ (fun (clause, clause_name) -> (clause.clause_id, clause_name))
+ (List.combine item_clauses clauses_info)
+ in
+ (item_name, (type_name, clauses)))
+ types
+ in
+ (* Register the names *)
+ List.fold_left
+ (fun ctx (item_name, (type_name, clauses)) ->
+ let ctx =
+ ctx_add (TraitItemId (trait_decl.def_id, item_name)) type_name ctx
+ in
+ List.fold_left
+ (fun ctx (clause_id, clause_name) ->
+ ctx_add
+ (TraitItemClauseId (trait_decl.def_id, item_name, clause_id))
+ clause_name ctx)
+ ctx clauses)
+ ctx type_names
+
+(** Similar to {!extract_trait_decl_register_names} *)
+let extract_trait_decl_method_names (ctx : extraction_ctx)
+ (trait_decl : trait_decl)
+ (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) :
+ extraction_ctx =
+ let required_methods = trait_decl.required_methods in
+ (* Compute the names *)
+ let method_names =
+ (* We add one field per required forward/backward function *)
+ let get_funs_for_id (id : fun_decl_id) : fun_decl list =
+ let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in
+ List.map (fun f -> f.f) (trans.fwd :: trans.backs)
+ in
+ match builtin_info with
+ | None ->
+ (* We add one field per required forward/backward function *)
+ let compute_item_names (item_name : string) (id : fun_decl_id) :
+ string * (RegionGroupId.id option * string) list =
+ let compute_fun_name (f : fun_decl) : RegionGroupId.id option * string
+ =
+ (* We do something special to reuse the [ctx_compute_fun_decl]
+ function. TODO: make it cleaner. *)
+ let basename : name = [ Ident item_name ] in
+ let f = { f with basename } in
+ let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in
+ let name = ctx_compute_fun_name trans f ctx in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx.fmt.trait_decl_name trait_decl ^ "_" ^ name
+ in
+ (f.back_id, name)
+ in
+ let funs = get_funs_for_id id in
+ (item_name, List.map compute_fun_name funs)
+ in
+ List.map (fun (name, id) -> compute_item_names name id) required_methods
+ | Some info ->
+ let funs_map = StringMap.of_list info.methods in
+ List.map
+ (fun (item_name, fun_id) ->
+ let open ExtractBuiltin in
+ let info = StringMap.find item_name funs_map in
+ let trans_funs = get_funs_for_id fun_id in
+ let find (trans_fun : fun_decl) =
+ let info =
+ List.find_opt
+ (fun (info : builtin_fun_info) -> info.rg = trans_fun.back_id)
+ info
+ in
+ match info with
+ | Some info -> (info.rg, info.extract_name)
+ | None ->
+ let err =
+ "Ill-formed builtin information for trait decl \""
+ ^ Names.name_to_string trait_decl.name
+ ^ "\", method \"" ^ item_name
+ ^ "\": could not find name for region "
+ ^ Print.option_to_string Pure.show_region_group_id
+ trans_fun.back_id
+ in
+ log#serror err;
+ if !Config.fail_hard then raise (Failure err)
+ else (trans_fun.back_id, "%ERROR_BUILTIN_NAME_NOT_FOUND%")
+ in
+ let rg_with_name_list = List.map find trans_funs in
+ (item_name, rg_with_name_list))
+ required_methods
+ in
+ (* Register the names *)
+ List.fold_left
+ (fun ctx (item_name, funs) ->
+ (* We add one field per required forward/backward function *)
+ List.fold_left
+ (fun ctx (rg, fun_name) ->
+ ctx_add
+ (TraitMethodId (trait_decl.def_id, item_name, rg))
+ fun_name ctx)
+ ctx funs)
+ ctx method_names
+
+(** Similar to {!extract_type_decl_register_names} *)
+let extract_trait_decl_register_names (ctx : extraction_ctx)
+ (trait_decl : trait_decl) : extraction_ctx =
+ (* Lookup the information if this is a builtin trait *)
+ let open ExtractBuiltin in
+ let sname = name_to_simple_name trait_decl.name in
+ let builtin_info =
+ SimpleNameMap.find_opt sname (builtin_trait_decls_map ())
+ in
+ let ctx =
+ let trait_name, trait_constructor =
+ match builtin_info with
+ | None ->
+ ( ctx.fmt.trait_decl_name trait_decl,
+ ctx.fmt.trait_decl_constructor trait_decl )
+ | Some info -> (info.extract_name, info.constructor)
+ in
+ let ctx = ctx_add (TraitDeclId trait_decl.def_id) trait_name ctx in
+ ctx_add (TraitDeclConstructorId trait_decl.def_id) trait_constructor ctx
+ in
+ (* Parent clauses *)
+ let ctx =
+ extract_trait_decl_register_parent_clause_names ctx trait_decl builtin_info
+ in
+ (* Constants *)
+ let ctx =
+ extract_trait_decl_register_constant_names ctx trait_decl builtin_info
+ in
+ (* Types *)
+ let ctx = extract_trait_decl_type_names ctx trait_decl builtin_info in
+ (* Required methods *)
+ let ctx = extract_trait_decl_method_names ctx trait_decl builtin_info in
+ ctx
+
+(** Similar to {!extract_type_decl_register_names} *)
+let extract_trait_impl_register_names (ctx : extraction_ctx)
+ (trait_impl : trait_impl) : extraction_ctx =
+ let decl_id = trait_impl.impl_trait.trait_decl_id in
+ let trait_decl = TraitDeclId.Map.find decl_id ctx.trans_trait_decls in
+ (* Check if the trait implementation is builtin *)
+ let builtin_info =
+ let open ExtractBuiltin in
+ let type_sname = name_to_simple_name trait_impl.name in
+ let trait_sname = name_to_simple_name trait_decl.name in
+ SimpleNamePairMap.find_opt (type_sname, trait_sname)
+ (builtin_trait_impls_map ())
+ in
+ (* Register some builtin information (if necessary) *)
+ let ctx, builtin_info =
+ match builtin_info with
+ | None -> (ctx, None)
+ | Some (filter, info) ->
+ let ctx =
+ match filter with
+ | None -> ctx
+ | Some filter ->
+ {
+ ctx with
+ trait_impls_filter_type_args_map =
+ TraitImplId.Map.add trait_impl.def_id filter
+ ctx.trait_impls_filter_type_args_map;
+ }
+ in
+ (ctx, Some info)
+ in
+
+ (* For now we do not support overriding provided methods *)
+ assert (trait_impl.provided_methods = []);
+ (* Everything is taken care of by {!extract_trait_decl_register_names} *but*
+ the name of the implementation itself *)
+ (* Compute the name *)
+ let name =
+ match builtin_info with
+ | None -> ctx.fmt.trait_impl_name trait_decl trait_impl
+ | Some name -> name
+ in
+ ctx_add (TraitImplId trait_impl.def_id) name ctx
+
+(** Small helper.
+
+ The type `ty` is to be understood in a very general sense.
+ *)
+let extract_trait_item (ctx : extraction_ctx) (fmt : F.formatter)
+ (item_name : string) (separator : string) (ty : unit -> unit) : unit =
+ F.pp_print_space fmt ();
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_print_string fmt item_name;
+ F.pp_print_space fmt ();
+ (* ":" or "=" *)
+ F.pp_print_string fmt separator;
+ ty ();
+ (match !Config.backend with Lean -> () | _ -> F.pp_print_string fmt ";");
+ F.pp_close_box fmt ()
+
+let extract_trait_decl_item (ctx : extraction_ctx) (fmt : F.formatter)
+ (item_name : string) (ty : unit -> unit) : unit =
+ extract_trait_item ctx fmt item_name ":" ty
+
+let extract_trait_impl_item (ctx : extraction_ctx) (fmt : F.formatter)
+ (item_name : string) (ty : unit -> unit) : unit =
+ let assign = match !Config.backend with Lean | Coq -> ":=" | _ -> "=" in
+ extract_trait_item ctx fmt item_name assign ty
+
+(** Small helper - TODO: move *)
+let generic_params_drop_prefix ~(drop_trait_clauses : bool)
+ (g1 : generic_params) (g2 : generic_params) : generic_params =
+ let open Collections.List in
+ let types = drop (length g1.types) g2.types in
+ let const_generics = drop (length g1.const_generics) g2.const_generics in
+ let trait_clauses =
+ if drop_trait_clauses then drop (length g1.trait_clauses) g2.trait_clauses
+ else g2.trait_clauses
+ in
+ { types; const_generics; trait_clauses }
+
+(** Small helper.
+
+ Extract the items for a method in a trait decl.
+ *)
+let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
+ (decl : trait_decl) (item_name : string) (id : fun_decl_id) : unit =
+ (* Lookup the definition *)
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+ (* Extract the items *)
+ let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
+ let extract_method (f : fun_and_loops) =
+ let f = f.f in
+ let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in
+ let ty () =
+ (* Extract the generics *)
+ (* We need to add the generics specific to the method, by removing those
+ which actually apply to the trait decl *)
+ let generics =
+ let drop_trait_clauses = false in
+ generic_params_drop_prefix ~drop_trait_clauses decl.generics
+ f.signature.generics
+ in
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params generics ctx
+ in
+ let backend_uses_forall =
+ match !backend with Coq | Lean -> true | FStar | HOL4 -> false
+ in
+ let generics_not_empty = generics <> empty_generic_params in
+ let use_forall = generics_not_empty && backend_uses_forall in
+ let use_arrows = generics_not_empty && not backend_uses_forall in
+ let use_forall_use_sep = false in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall
+ ~use_forall_use_sep ~use_arrows generics type_params cg_params
+ trait_clauses;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the inputs and output *)
+ F.pp_print_space fmt ();
+ extract_fun_inputs_output_parameters_types ctx fmt f
+ in
+ extract_trait_decl_item ctx fmt fun_name ty
+ in
+ List.iter extract_method funs
+
+(** Extract a trait declaration *)
+let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter)
+ (decl : trait_decl) : unit =
+ (* Retrieve the trait name *)
+ let decl_name = ctx_get_trait_decl decl.def_id ctx in
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment to link the extracted type to its original rust definition *)
+ extract_comment fmt
+ [ "Trait declaration: [" ^ Print.name_to_string decl.name ^ "]" ];
+ F.pp_print_break fmt 0 0;
+ (* Open two outer boxes for the definition, so that whenever possible it gets printed on
+ one line and indents are correct.
+
+ There is just an exception with Lean: in this backend, line breaks are important
+ for the parsing, so we always open a vertical box.
+ *)
+ if !Config.backend = Lean then F.pp_open_vbox fmt ctx.indent_incr
+ else (
+ F.pp_open_hvbox fmt 0;
+ F.pp_open_hvbox fmt ctx.indent_incr);
+
+ (* `struct Trait (....) =` *)
+ (* Open the box for the name + generics *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ let qualif =
+ Option.get (ctx.fmt.type_decl_kind_to_qualif SingleNonRec (Some Struct))
+ in
+ (* When checking if the trait declaration is empty: we ignore the provided
+ methods, because for now they are extracted separately *)
+ let is_empty = trait_decl_is_empty { decl with provided_methods = [] } in
+ if !backend = FStar && not is_empty then (
+ F.pp_print_string fmt "noeq";
+ F.pp_print_space fmt ());
+ F.pp_print_string fmt qualif;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt decl_name;
+ (* Print the generics *)
+ let generics = decl.generics in
+ (* Add the type and const generic params - note that we need those bindings only for the
+ * body translation (they are not top-level) *)
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params generics ctx
+ in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty generics type_params
+ cg_params trait_clauses;
+
+ F.pp_print_space fmt ();
+ if is_empty && !backend = FStar then (
+ F.pp_print_string fmt "= unit";
+ (* Outer box *)
+ F.pp_close_box fmt ())
+ else if is_empty && !backend = Coq then (
+ (* Coq is not very good at infering constructors *)
+ let cons = ctx_get_trait_constructor decl.def_id ctx in
+ F.pp_print_string fmt (":= " ^ cons ^ "{}.");
+ (* Outer box *)
+ F.pp_close_box fmt ())
+ else (
+ (match !backend with
+ | Lean -> F.pp_print_string fmt "where"
+ | FStar -> F.pp_print_string fmt "= {"
+ | Coq ->
+ let cons = ctx_get_trait_constructor decl.def_id ctx in
+ F.pp_print_string fmt (":= " ^ cons ^ " {")
+ | _ -> F.pp_print_string fmt "{");
+
+ (* Close the box for the name + generics *)
+ F.pp_close_box fmt ();
+
+ (*
+ * Extract the items
+ *)
+
+ (* The constants *)
+ List.iter
+ (fun (name, (ty, _)) ->
+ let item_name = ctx_get_trait_const decl.def_id name ctx in
+ let ty () =
+ let inside = false in
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt TypeDeclId.Set.empty inside ty
+ in
+ extract_trait_decl_item ctx fmt item_name ty)
+ decl.consts;
+
+ (* The types *)
+ List.iter
+ (fun (name, (clauses, _)) ->
+ (* Extract the type *)
+ let item_name = ctx_get_trait_type decl.def_id name ctx in
+ let ty () =
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (type_keyword ())
+ in
+ extract_trait_decl_item ctx fmt item_name ty;
+ (* Extract the clauses *)
+ List.iter
+ (fun clause ->
+ let item_name =
+ ctx_get_trait_item_clause decl.def_id name clause.clause_id ctx
+ in
+ let ty () =
+ F.pp_print_space fmt ();
+ extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause
+ in
+ extract_trait_decl_item ctx fmt item_name ty)
+ clauses)
+ decl.types;
+
+ (* The parent clauses - note that the parent clauses may refer to the types
+ and const generics: for this reason we extract them *after* *)
+ List.iter
+ (fun clause ->
+ let item_name =
+ ctx_get_trait_parent_clause decl.def_id clause.clause_id ctx
+ in
+ let ty () =
+ F.pp_print_space fmt ();
+ extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause
+ in
+ extract_trait_decl_item ctx fmt item_name ty)
+ decl.parent_clauses;
+
+ (* The required methods *)
+ List.iter
+ (fun (name, id) -> extract_trait_decl_method_items ctx fmt decl name id)
+ decl.required_methods;
+
+ (* Close the outer boxes for the definition *)
+ if !Config.backend <> Lean then F.pp_close_box fmt ();
+ (* Close the brackets *)
+ match !Config.backend with
+ | Lean -> ()
+ | Coq ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}."
+ | _ ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}");
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+
+(** Generate the [Arguments] instructions for the trait declarationsin Coq, so
+ that we don't have to provide the implicit arguments when projecting the fields. *)
+let extract_trait_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
+ (decl : trait_decl) : unit =
+ (* Generating the [Arguments] instructions is useful only if there are parameters *)
+ let num_params =
+ List.length decl.generics.types
+ + List.length decl.generics.const_generics
+ + List.length decl.generics.trait_clauses
+ in
+ if num_params > 0 then (
+ (* The constructor *)
+ let cons_name = ctx_get_trait_constructor decl.def_id ctx in
+ extract_coq_arguments_instruction ctx fmt cons_name num_params;
+ (* The constants *)
+ List.iter
+ (fun (name, _) ->
+ let item_name = ctx_get_trait_const decl.def_id name ctx in
+ extract_coq_arguments_instruction ctx fmt item_name num_params)
+ decl.consts;
+ (* The types *)
+ List.iter
+ (fun (name, (clauses, _)) ->
+ (* The type *)
+ let item_name = ctx_get_trait_type decl.def_id name ctx in
+ extract_coq_arguments_instruction ctx fmt item_name num_params;
+ (* The type clauses *)
+ List.iter
+ (fun clause ->
+ let item_name =
+ ctx_get_trait_item_clause decl.def_id name clause.clause_id ctx
+ in
+ extract_coq_arguments_instruction ctx fmt item_name num_params)
+ clauses)
+ decl.types;
+ (* The parent clauses *)
+ List.iter
+ (fun clause ->
+ let item_name =
+ ctx_get_trait_parent_clause decl.def_id clause.clause_id ctx
+ in
+ extract_coq_arguments_instruction ctx fmt item_name num_params)
+ decl.parent_clauses;
+ (* The required methods *)
+ List.iter
+ (fun (item_name, id) ->
+ (* Lookup the definition *)
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+ (* Extract the items *)
+ let funs =
+ if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs
+ in
+ let extract_for_method (f : fun_and_loops) =
+ let f = f.f in
+ let item_name =
+ ctx_get_trait_method decl.def_id item_name f.back_id ctx
+ in
+ extract_coq_arguments_instruction ctx fmt item_name num_params
+ in
+ List.iter extract_for_method funs)
+ decl.required_methods;
+ (* Add a space *)
+ F.pp_print_space fmt ())
+
+(** See {!extract_trait_decl_coq_arguments} *)
+let extract_trait_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter)
+ (trait_decl : trait_decl) : unit =
+ match !backend with
+ | Coq -> extract_trait_decl_coq_arguments ctx fmt trait_decl
+ | _ -> ()
+
+(** Small helper.
+
+ Extract the items for a method in a trait impl.
+ *)
+let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
+ (impl : trait_impl) (item_name : string) (id : fun_decl_id)
+ (impl_generics : string list * string list * string list) : unit =
+ let trait_decl_id = impl.impl_trait.trait_decl_id in
+ (* Lookup the definition *)
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+ (* Extract the items *)
+ let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
+ let extract_method (f : fun_and_loops) =
+ let f = f.f in
+ let fun_name = ctx_get_trait_method trait_decl_id item_name f.back_id ctx in
+ let ty () =
+ (* Filter the generics if the method is a builtin *)
+ let i_tys, _, _ = impl_generics in
+ let impl_types, i_tys, f_tys =
+ match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with
+ | None -> (impl.generics.types, i_tys, f.signature.generics.types)
+ | Some filter ->
+ let filter_list filter ls =
+ let ls = List.combine filter ls in
+ List.filter_map (fun (b, ty) -> if b then Some ty else None) ls
+ in
+ let impl_types = impl.generics.types in
+ let impl_filter =
+ Collections.List.prefix (List.length impl_types) filter
+ in
+ let i_tys = i_tys in
+ let i_filter = Collections.List.prefix (List.length i_tys) filter in
+ ( filter_list impl_filter impl_types,
+ filter_list i_filter i_tys,
+ filter_list filter f.signature.generics.types )
+ in
+ let f_generics = { f.signature.generics with types = f_tys } in
+ (* Extract the generics - we need to quantify over the generics which
+ are specific to the method, and call it will all the generics
+ (trait impl + method generics) *)
+ let f_generics =
+ let drop_trait_clauses = true in
+ generic_params_drop_prefix ~drop_trait_clauses
+ { impl.generics with types = impl_types }
+ f_generics
+ in
+ (* Register and print the quantified generics *)
+ let ctx, f_tys, f_cgs, f_tcs = ctx_add_generic_params f_generics ctx in
+ let use_forall = f_generics <> empty_generic_params in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics
+ f_tys f_cgs f_tcs;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the function call *)
+ F.pp_print_space fmt ();
+ let fun_name = ctx_get_local_function f.def_id None f.back_id ctx in
+ F.pp_print_string fmt fun_name;
+ let all_generics =
+ let _, i_cgs, i_tcs = impl_generics in
+ List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ]
+ in
+
+ (* Filter the generics if the function is builtin *)
+ List.iter
+ (fun p ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt p)
+ all_generics
+ in
+ extract_trait_impl_item ctx fmt fun_name ty
+ in
+ List.iter extract_method funs
+
+(** Extract a trait implementation *)
+let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
+ (impl : trait_impl) : unit =
+ log#ldebug (lazy ("extract_trait_impl: " ^ Names.name_to_string impl.name));
+ (* Retrieve the impl name *)
+ let impl_name = ctx_get_trait_impl impl.def_id ctx in
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment to link the extracted type to its original rust definition *)
+ extract_comment fmt
+ [ "Trait implementation: [" ^ Print.name_to_string impl.name ^ "]" ];
+ F.pp_print_break fmt 0 0;
+
+ (* Open two outer boxes for the definition, so that whenever possible it gets printed on
+ one line and indents are correct.
+
+ There is just an exception with Lean: in this backend, line breaks are important
+ for the parsing, so we always open a vertical box.
+ *)
+ if !Config.backend = Lean then (
+ F.pp_open_vbox fmt 0;
+ F.pp_open_vbox fmt ctx.indent_incr)
+ else (
+ F.pp_open_hvbox fmt 0;
+ F.pp_open_hvbox fmt ctx.indent_incr);
+
+ (* `let (....) : Trait ... =` *)
+ (* Open the box for the name + generics *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ (match ctx.fmt.fun_decl_kind_to_qualif SingleNonRec with
+ | Some qualif ->
+ F.pp_print_string fmt qualif;
+ F.pp_print_space fmt ()
+ | None -> ());
+ F.pp_print_string fmt impl_name;
+
+ (* Print the generics *)
+ (* Add the type and const generic params - note that we need those bindings only for the
+ * body translation (they are not top-level) *)
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params impl.generics ctx
+ in
+ let all_generics = (type_params, cg_params, trait_clauses) in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty impl.generics type_params
+ cg_params trait_clauses;
+
+ (* Print the type *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_trait_decl_ref ctx fmt TypeDeclId.Set.empty false impl.impl_trait;
+
+ (* When checking if the trait impl is empty: we ignore the provided
+ methods, because for now they are extracted separately *)
+ let is_empty = trait_impl_is_empty { impl with provided_methods = [] } in
+
+ F.pp_print_space fmt ();
+ if is_empty && !Config.backend = FStar then (
+ F.pp_print_string fmt "= ()";
+ (* Outer box *)
+ F.pp_close_box fmt ())
+ else if is_empty && !Config.backend = Coq then (
+ (* Coq is not very good at infering constructors *)
+ let cons = ctx_get_trait_constructor impl.impl_trait.trait_decl_id ctx in
+ F.pp_print_string fmt (":= " ^ cons ^ ".");
+ (* Outer box *)
+ F.pp_close_box fmt ())
+ else (
+ if !Config.backend = Lean then F.pp_print_string fmt ":= {"
+ else if !Config.backend = Coq then F.pp_print_string fmt ":= {|"
+ else F.pp_print_string fmt "= {";
+
+ (* Close the box for the name + generics *)
+ F.pp_close_box fmt ();
+
+ (*
+ * Extract the items
+ *)
+ let trait_decl_id = impl.impl_trait.trait_decl_id in
+
+ (* The constants *)
+ List.iter
+ (fun (name, (_, id)) ->
+ let item_name = ctx_get_trait_const trait_decl_id name ctx in
+ let ty () =
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (ctx_get_global id ctx)
+ in
+
+ extract_trait_impl_item ctx fmt item_name ty)
+ impl.consts;
+
+ (* The types *)
+ List.iter
+ (fun (name, (trait_refs, ty)) ->
+ (* Extract the type *)
+ let item_name = ctx_get_trait_type trait_decl_id name ctx in
+ let ty () =
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt TypeDeclId.Set.empty false ty
+ in
+ extract_trait_impl_item ctx fmt item_name ty;
+ (* Extract the clauses *)
+ TraitClauseId.iteri
+ (fun clause_id trait_ref ->
+ let item_name =
+ ctx_get_trait_item_clause trait_decl_id name clause_id ctx
+ in
+ let ty () =
+ F.pp_print_space fmt ();
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref
+ in
+ extract_trait_impl_item ctx fmt item_name ty)
+ trait_refs)
+ impl.types;
+
+ (* The parent clauses *)
+ TraitClauseId.iteri
+ (fun clause_id trait_ref ->
+ let item_name =
+ ctx_get_trait_parent_clause trait_decl_id clause_id ctx
+ in
+ let ty () =
+ F.pp_print_space fmt ();
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref
+ in
+ extract_trait_impl_item ctx fmt item_name ty)
+ impl.parent_trait_refs;
+
+ (* The required methods *)
+ List.iter
+ (fun (name, id) ->
+ extract_trait_impl_method_items ctx fmt impl name id all_generics)
+ impl.required_methods;
+
+ (* Close the outer boxes for the definition, as well as the brackets *)
+ F.pp_close_box fmt ();
+ if !backend = Coq then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "|}.")
+ else if (not (!backend = FStar)) || not is_empty then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}"));
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+
(** Extract a unit test, if the function is a unit function (takes no
parameters, returns unit).
@@ -3735,8 +2632,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
(* Check if this is a unit function *)
let sg = def.signature in
if
- sg.type_params = []
- && sg.const_generic_params = []
+ sg.generics = empty_generic_params
&& (sg.inputs = [ mk_unit_ty ] || sg.inputs = [])
&& sg.output = mk_result_ty mk_unit_ty
then (
@@ -3756,12 +2652,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "assert_norm";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- (* Note that if the function is opaque, the unit test will fail
- because the normalizer will get stuck *)
- let with_opaque_pre = ctx.use_opaque_pre in
let fun_name =
- ctx_get_local_function with_opaque_pre def.def_id def.loop_id
- def.back_id ctx
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx
in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
@@ -3776,12 +2668,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "Check";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- (* Note that if the function is opaque, the unit test will fail
- because the normalizer will get stuck *)
- let with_opaque_pre = ctx.use_opaque_pre in
let fun_name =
- ctx_get_local_function with_opaque_pre def.def_id def.loop_id
- def.back_id ctx
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx
in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
@@ -3793,12 +2681,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "#assert";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- (* Note that if the function is opaque, the unit test will fail
- because the normalizer will get stuck *)
- let with_opaque_pre = ctx.use_opaque_pre in
let fun_name =
- ctx_get_local_function with_opaque_pre def.def_id def.loop_id
- def.back_id ctx
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx
in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
@@ -3812,12 +2696,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
| HOL4 ->
F.pp_print_string fmt "val _ = assert_return (";
F.pp_print_string fmt "“";
- (* Note that if the function is opaque, the unit test will fail
- because the normalizer will get stuck *)
- let with_opaque_pre = ctx.use_opaque_pre in
let fun_name =
- ctx_get_local_function with_opaque_pre def.def_id def.loop_id
- def.back_id ctx
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx
in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index d733c763..31b1a447 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -5,9 +5,10 @@ open TranslateCore
module C = Contexts
module RegionVarId = T.RegionVarId
module F = Format
+open ExtractBuiltin
(** The local logger *)
-let log = L.pure_to_extract_log
+let log = L.extract_log
type region_group_info = {
id : RegionGroupId.id;
@@ -21,8 +22,8 @@ type region_group_info = {
*)
}
-module StringSet = Collections.MakeSet (Collections.OrderedString)
-module StringMap = Collections.MakeMap (Collections.OrderedString)
+module StringSet = Collections.StringSet
+module StringMap = Collections.StringMap
type name = Names.name
type type_name = Names.type_name
@@ -77,6 +78,7 @@ type decl_kind =
F*: [val x : Type0]
Coq: [Axiom x : Type.]
*)
+[@@deriving show]
(** Return [true] if the declaration is the last from its group of declarations.
@@ -111,9 +113,9 @@ let decl_is_first_from_group (kind : decl_kind) : bool =
let decl_is_not_last_from_group (kind : decl_kind) : bool =
not (decl_is_last_from_group kind)
-(* TODO: this should a module we give to a functor! *)
+type type_decl_kind = Enum | Struct [@@deriving show]
-type type_decl_kind = Enum | Struct
+(* TODO: this should be a module we give to a functor! *)
(** A formatter's role is twofold:
1. Come up with name suggestions.
@@ -125,6 +127,9 @@ type type_decl_kind = Enum | Struct
snake case, adding prefixes/suffixes, etc.
2. Format some specific terms, like constants.
+
+ TODO: unclear that this is useful now that all the backends are so much
+ entangled in Extract.ml
*)
type formatter = {
bool_name : string;
@@ -239,37 +244,14 @@ type formatter = {
the same purpose as in {!field:fun_name}.
- loop identifier, if this is for a loop
*)
- opaque_pre : unit -> string;
- (** TODO: obsolete, remove.
-
- The prefix to use for opaque definitions.
-
- We need this because for some backends like Lean and Coq, we group
- opaque definitions in module signatures, meaning that using those
- definitions requires to prefix them with a module parameter name (such
- as "opaque_defs.").
-
- For instance, if we have an opaque function [f : int -> int], which
- is used by the non-opaque function [g], we would generate (in Coq):
- {[
- (* The module signature declaring the opaque definitions *)
- module type OpaqueDefs = {
- f_fwd : int -> int
- ... (* Other definitions *)
- }
-
- (* The definitions generated for the non-opaque definitions *)
- module Funs (opaque: OpaqueDefs) = {
- let g ... =
- ...
- opaque_defs.f_fwd
- ...
- }
- ]}
-
- Upon using [f] in [g], we don't directly use the the name "f_fwd",
- but prefix it with the "opaque_defs." identifier.
- *)
+ trait_decl_name : trait_decl -> string;
+ trait_impl_name : trait_decl -> trait_impl -> string;
+ trait_decl_constructor : trait_decl -> string;
+ trait_parent_clause_name : trait_decl -> trait_clause -> string;
+ trait_const_name : trait_decl -> string -> string;
+ trait_type_name : trait_decl -> string -> string;
+ trait_method_name : trait_decl -> string -> string;
+ trait_type_clause_name : trait_decl -> string -> trait_clause -> string;
var_basename : StringSet.t -> string option -> ty -> string;
(** Generates a variable basename.
@@ -288,6 +270,14 @@ type formatter = {
(** Generates a type variable basename. *)
const_generic_var_basename : StringSet.t -> string -> string;
(** Generates a const generic variable basename. *)
+ trait_self_clause_basename : string;
+ trait_clause_basename : StringSet.t -> trait_clause -> string;
+ (** Return a base name for a trait clause. We might add a suffix to prevent
+ collisions.
+
+ In the traduction we explicitely manipulate the trait clause instances,
+ that is we introduce one input variable for each trait clause.
+ *)
append_index : string -> int -> string;
(** Appends an index to a name - we use this to generate unique
names: when doing so, the role of the formatter is just to concatenate
@@ -396,10 +386,60 @@ type id =
| TypeVarId of TypeVarId.id
| ConstGenericVarId of ConstGenericVarId.id
| VarId of VarId.id
+ | TraitDeclId of TraitDeclId.id
+ | TraitImplId of TraitImplId.id
+ | LocalTraitClauseId of TraitClauseId.id
+ | TraitDeclConstructorId of TraitDeclId.id
+ | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option
+ (** Something peculiar with trait methods: because we have to take into
+ account forward/backward functions, we may need to generate fields
+ items per method.
+ *)
+ | TraitItemId of TraitDeclId.id * string
+ (** A trait associated item which is not a method *)
+ | TraitParentClauseId of TraitDeclId.id * TraitClauseId.id
+ | TraitItemClauseId of TraitDeclId.id * string * TraitClauseId.id
+ | TraitSelfClauseId
+ (** Specifically for the clause: [Self : Trait].
+
+ For now, we forbid provided methods (methods in trait declarations
+ with a default implementation) from being overriden in trait implementations.
+ We extract trait provided methods such that they take an instance of
+ the trait as input: this instance is given by the trait self clause.
+
+ For instance:
+ {[
+ //
+ // Rust
+ //
+ trait ToU64 {
+ fn to_u64(&self) -> u64;
+
+ // Provided method
+ fn is_pos(&self) -> bool {
+ self.to_u64() > 0
+ }
+ }
+
+ //
+ // Generated code
+ //
+ struct ToU64 (T : Type) {
+ to_u64 : T -> u64;
+ }
+
+ // The trait self clause
+ // vvvvvvvvvvvvvvvvvvvvvv
+ let is_pos (T : Type) (trait_self : ToU64 T) (self : T) : bool =
+ trait_self.to_u64 self > 0
+ ]}
+ *)
| UnknownId
(** Used for stored various strings like keywords, definitions which
should always be in context, etc. and which can't be linked to one
of the above.
+
+ TODO: rename to "keyword"
*)
[@@deriving show, ord]
@@ -429,69 +469,64 @@ type names_map = {
precisely which identifiers are mapped to the same name...
*)
names_set : StringSet.t;
- opaque_ids : IdSet.t;
- (** TODO: this is obsolete. Remove.
+}
- The set of opaque definitions.
+let empty_names_map : names_map =
+ {
+ id_to_name = IdMap.empty;
+ name_to_id = StringMap.empty;
+ names_set = StringSet.empty;
+ }
- See {!formatter.opaque_pre} for detailed explanations about why
- we need to know which definitions are opaque to compute names.
+(** Small helper to report name collision *)
+let report_name_collision (id_to_string : id -> string) (id1 : id) (id2 : id)
+ (name : string) : unit =
+ let id1 = "\n- " ^ id_to_string id1 in
+ let id2 = "\n- " ^ id_to_string id2 in
+ let err =
+ "Name clash detected: the following identifiers are bound to the same name \
+ \"" ^ name ^ "\":" ^ id1 ^ id2
+ ^ "\nYou may want to rename some of your definitions, or report an issue."
+ in
+ log#serror err;
+ (* If we fail hard on errors, raise an exception *)
+ if !Config.fail_hard then raise (Failure err)
- Also note that the opaque ids don't contain the ids of the assumed
- definitions. In practice, assumed definitions are opaque_defs. However, they
- are not grouped in the opaque module, meaning we never need to
- prefix them (with, say, "opaque_defs."): we thus consider them as non-opaque
- with regards to the names map.
- *)
-}
+let names_map_get_id_from_name (name : string) (nm : names_map) : id option =
+ StringMap.find_opt name nm.name_to_id
-let names_map_add (id_to_string : id -> string) (is_opaque : bool) (id : id)
- (name : string) (nm : names_map) : names_map =
- (* Check if there is a clash *)
- (match StringMap.find_opt name nm.name_to_id with
+let names_map_check_collision (id_to_string : id -> string) (id : id)
+ (name : string) (nm : names_map) : unit =
+ match names_map_get_id_from_name name nm with
| None -> () (* Ok *)
| Some clash ->
(* There is a clash: print a nice debugging message for the user *)
- let id1 = "\n- " ^ id_to_string clash in
- let id2 = "\n- " ^ id_to_string id in
- let err =
- "Name clash detected: the following identifiers are bound to the same \
- name \"" ^ name ^ "\":" ^ id1 ^ id2
- in
- log#serror err;
- raise (Failure err));
- (* Sanity check *)
- assert (not (StringSet.mem name nm.names_set));
+ report_name_collision id_to_string clash id name
+
+(** Insert bindings in a names map without checking for collisions *)
+let names_map_add_unchecked (id : id) (name : string) (nm : names_map) :
+ names_map =
(* Insert *)
let id_to_name = IdMap.add id name nm.id_to_name in
let name_to_id = StringMap.add name id nm.name_to_id in
let names_set = StringSet.add name nm.names_set in
- let opaque_ids =
- if is_opaque then IdSet.add id nm.opaque_ids else nm.opaque_ids
- in
- { id_to_name; name_to_id; names_set; opaque_ids }
-
-let names_map_add_assumed_type (id_to_string : id -> string) (id : assumed_ty)
- (name : string) (nm : names_map) : names_map =
- let is_opaque = false in
- names_map_add id_to_string is_opaque (TypeId (Assumed id)) name nm
-
-let names_map_add_assumed_struct (id_to_string : id -> string) (id : assumed_ty)
- (name : string) (nm : names_map) : names_map =
- let is_opaque = false in
- names_map_add id_to_string is_opaque (StructId (Assumed id)) name nm
+ { id_to_name; name_to_id; names_set }
-let names_map_add_assumed_variant (id_to_string : id -> string)
- (id : assumed_ty) (variant_id : VariantId.id) (name : string)
+let names_map_add (id_to_string : id -> string) (id : id) (name : string)
(nm : names_map) : names_map =
- let is_opaque = false in
- names_map_add id_to_string is_opaque
- (VariantId (Assumed id, variant_id))
- name nm
-
-let names_map_add_function (id_to_string : id -> string) (is_opaque : bool)
- (fid : fun_id) (name : string) (nm : names_map) : names_map =
- names_map_add id_to_string is_opaque (FunId fid) name nm
+ (* Check if there is a clash *)
+ names_map_check_collision id_to_string id name nm;
+ (* Sanity check *)
+ if StringSet.mem name nm.names_set then (
+ let err =
+ "Error when registering the name for id: " ^ id_to_string id
+ ^ ":\nThe chosen name is already in the names set: " ^ name
+ in
+ log#serror err;
+ (* If we fail hard on errors, raise an exception *)
+ if !Config.fail_hard then raise (Failure err));
+ (* Insert *)
+ names_map_add_unchecked id name nm
(** The unsafe names map stores mappings from identifiers to names which might
collide. For some backends and some names, it might be acceptable to have
@@ -503,6 +538,8 @@ let names_map_add_function (id_to_string : id -> string) (is_opaque : bool)
*)
type unsafe_names_map = { id_to_name : string IdMap.t }
+let empty_unsafe_names_map = { id_to_name = IdMap.empty }
+
let unsafe_names_map_add (id : id) (name : string) (nm : unsafe_names_map) :
unsafe_names_map =
{ id_to_name = IdMap.add id name nm.id_to_name }
@@ -541,6 +578,24 @@ let basename_to_unique (names_set : StringSet.t)
type fun_name_info = { keep_fwd : bool; num_backs : int }
+type names_maps = {
+ names_map : names_map;
+ (** The map for id to names, where we forbid name collisions
+ (ex.: we always forbid function name collisions). *)
+ unsafe_names_map : unsafe_names_map;
+ (** The map for id to names, where we allow name collisions
+ (ex.: we might allow record field name collisions). *)
+ strict_names_map : names_map;
+ (** This map is a sub-map of [names_map]. For the ids in this map we also
+ forbid collisions with names in the [unsafe_names_map].
+
+ We do so for keywords for instance, but also for types (in a dependently
+ typed language, we might have an issue if the field of a record has, say,
+ the name "u32", and another field of the same record refers to "u32"
+ (for instance in its type).
+ *)
+}
+
(** Extraction context.
Note that the extraction context contains information coming from the
@@ -549,24 +604,12 @@ type fun_name_info = { keep_fwd : bool; num_backs : int }
functions, etc.
*)
type extraction_ctx = {
+ crate : A.crate;
trans_ctx : trans_ctx;
- names_map : names_map;
- (** The map for id to names, where we forbid name collisions
- (ex.: we always forbid function name collisions). *)
- unsafe_names_map : unsafe_names_map;
- (** The map for id to names, where we allow name collisions
- (ex.: we might allow record field name collisions). *)
+ names_maps : names_maps;
fmt : formatter;
indent_incr : int;
(** The indent increment we insert whenever we need to indent more *)
- use_opaque_pre : bool;
- (** Do we use the "opaque_defs." prefix for the opaque definitions?
-
- Opaque function definitions might refer opaque types: if we are in the
- opaque module, we musn't use the "opaque_defs." prefix, otherwise we
- use it.
- Also see {!names_map.opaque_ids}.
- *)
use_dep_ite : bool;
(** For Lean: do we use dependent-if then else expressions?
@@ -586,6 +629,29 @@ type extraction_ctx = {
in case a Rust function only has one backward translation
and we filter the forward function because it returns unit.
*)
+ trait_decl_id : trait_decl_id option;
+ (** If we are extracting a trait declaration, identifies it *)
+ is_provided_method : bool;
+ trans_types : Pure.type_decl Pure.TypeDeclId.Map.t;
+ trans_funs : pure_fun_translation A.FunDeclId.Map.t;
+ functions_with_decreases_clause : PureUtils.FunLoopIdSet.t;
+ trans_trait_decls : Pure.trait_decl Pure.TraitDeclId.Map.t;
+ trans_trait_impls : Pure.trait_impl Pure.TraitImplId.Map.t;
+ types_filter_type_args_map : bool list TypeDeclId.Map.t;
+ (** The map to filter the type arguments for the builtin type
+ definitions.
+
+ We need this for type `Vec`, for instance, which takes a useless
+ (in the context of the type translation) type argument for the
+ allocator which is used, and which we want to remove.
+
+ TODO: it would be cleaner to filter those types in a micro-pass,
+ rather than at code generation time.
+ *)
+ funs_filter_type_args_map : bool list FunDeclId.Map.t;
+ (** Same as {!types_filter_type_args_map}, but for functions *)
+ trait_impls_filter_type_args_map : bool list TraitImplId.Map.t;
+ (** Same as {!types_filter_type_args_map}, but for trait implementations *)
}
(** Debugging function, used when communicating name collisions to the user,
@@ -593,9 +659,16 @@ type extraction_ctx = {
instance).
*)
let id_to_string (id : id) (ctx : extraction_ctx) : string =
- let global_decls = ctx.trans_ctx.global_context.global_decls in
- let fun_decls = ctx.trans_ctx.fun_context.fun_decls in
- let type_decls = ctx.trans_ctx.type_context.type_decls in
+ let global_decls = ctx.trans_ctx.global_ctx.global_decls in
+ let fun_decls = ctx.trans_ctx.fun_ctx.fun_decls in
+ let type_decls = ctx.trans_ctx.type_ctx.type_decls in
+ let trait_decls = ctx.trans_ctx.trait_decls_ctx.trait_decls in
+ let trait_decl_id_to_string (id : A.TraitDeclId.id) : string =
+ let trait_name =
+ Print.fun_name_to_string (A.TraitDeclId.Map.find id trait_decls).name
+ in
+ "trait_decl: " ^ trait_name ^ " (id: " ^ A.TraitDeclId.to_string id ^ ")"
+ in
(* TODO: factorize the pretty-printing with what is in PrintPure *)
let get_type_name (id : type_id) : string =
match id with
@@ -614,10 +687,17 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| FromLlbc (fid, lp_id, rg_id) ->
let fun_name =
match fid with
- | Regular fid ->
+ | FunId (Regular fid) ->
Print.fun_name_to_string
(A.FunDeclId.Map.find fid fun_decls).name
- | Assumed aid -> A.show_assumed_fun_id aid
+ | FunId (Assumed aid) -> A.show_assumed_fun_id aid
+ | TraitMethod (trait_ref, method_name, _) ->
+ (* Shouldn't happen *)
+ if !Config.fail_hard then raise (Failure "Unexpected")
+ else
+ "Trait method: decl: "
+ ^ TraitDeclId.to_string trait_ref.trait_decl_ref.trait_decl_id
+ ^ ", method_name: " ^ method_name
in
let lp_kind =
@@ -673,12 +753,16 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
if variant_id = error_failure_id then "@error::Failure"
else if variant_id = error_out_of_fuel_id then "@error::OutOfFuel"
else raise (Failure "Unreachable")
- | Assumed Option ->
- if variant_id = option_some_id then "@option::Some"
- else if variant_id = option_none_id then "@option::None"
+ | Assumed Fuel ->
+ if variant_id = fuel_zero_id then "@fuel::0"
+ else if variant_id = fuel_succ_id then "@fuel::Succ"
else raise (Failure "Unreachable")
- | Assumed (State | Vec | Fuel | Array | Slice | Str | Range) ->
- raise (Failure "Unreachable")
+ | Assumed (State | Array | Slice | Str | RawPtr _) ->
+ raise
+ (Failure
+ ("Unreachable: variant id ("
+ ^ VariantId.to_string variant_id
+ ^ ") for " ^ show_type_id id))
| AdtId id -> (
let def = TypeDeclId.Map.find id type_decls in
match def.kind with
@@ -693,8 +777,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
match id with
| Tuple -> raise (Failure "Unreachable")
| Assumed
- ( State | Result | Error | Fuel | Option | Vec | Array | Slice | Str
- | Range ) ->
+ (State | Result | Error | Fuel | Array | Slice | Str | RawPtr _) ->
(* We can't directly have access to the fields of those types *)
raise (Failure "Unreachable")
| AdtId id -> (
@@ -716,134 +799,265 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| ConstGenericVarId id ->
"const_generic_var_id: " ^ ConstGenericVarId.to_string id
| VarId id -> "var_id: " ^ VarId.to_string id
+ | TraitDeclId id -> "trait_decl_id: " ^ TraitDeclId.to_string id
+ | TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id
+ | LocalTraitClauseId id ->
+ "local_trait_clause_id: " ^ TraitClauseId.to_string id
+ | TraitDeclConstructorId id ->
+ "trait_decl_constructor: " ^ trait_decl_id_to_string id
+ | TraitParentClauseId (id, clause_id) ->
+ "trait_parent_clause_id: " ^ trait_decl_id_to_string id ^ ", clause_id: "
+ ^ TraitClauseId.to_string clause_id
+ | TraitItemClauseId (id, item_name, clause_id) ->
+ "trait_item_clause_id: " ^ trait_decl_id_to_string id ^ ", item name: "
+ ^ item_name ^ ", clause_id: "
+ ^ TraitClauseId.to_string clause_id
+ | TraitItemId (id, name) ->
+ "trait_item_id: " ^ trait_decl_id_to_string id ^ ", type name: " ^ name
+ | TraitMethodId (trait_decl_id, fun_name, rg_id) ->
+ let fwd_back_kind =
+ match rg_id with
+ | None -> "forward"
+ | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id
+ in
+ trait_decl_id_to_string trait_decl_id
+ ^ ", method name (" ^ fwd_back_kind ^ "): " ^ fun_name
+ | TraitSelfClauseId -> "trait_self_clause"
+
+(** Return [true] if we are strict on collisions for this id (i.e., we forbid
+ collisions even with the ids in the unsafe names map) *)
+let strict_collisions (id : id) : bool =
+ match id with UnknownId | TypeId _ -> true | _ -> false
(** We might not check for collisions for some specific ids (ex.: field names) *)
let allow_collisions (id : id) : bool =
match id with
- | FieldId (_, _) -> !Config.record_fields_short_names
+ | FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitItemId _
+ | TraitMethodId _ ->
+ !Config.record_fields_short_names
+ | FunId (Pure _ | FromLlbc (FunId (Assumed _), _, _)) ->
+ (* We map several assumed functions to the same id *)
+ true
| _ -> false
-let ctx_add (is_opaque : bool) (id : id) (name : string) (ctx : extraction_ctx)
- : extraction_ctx =
- (* We do not use the same name map if we allow/disallow collisions *)
+(** The [id_to_string] function to print nice debugging messages if there are
+ collisions *)
+let names_maps_add (id_to_string : id -> string) (id : id) (name : string)
+ (nm : names_maps) : names_maps =
+ (* We do not use the same name map if we allow/disallow collisions.
+ We notably use it for field names: some backends like Lean can use the
+ type information to disambiguate field projections.
+
+ Remark: we still need to check that those "unsafe" ids don't collide with
+ the ids that we mark as "strict on collision".
+
+ For instance, we don't allow naming a field "let". We enforce this by
+ not checking collision between ids for which we permit collisions (ex.:
+ between fields), but still checking collisions between those ids and the
+ others (ex.: fields and keywords).
+ *)
if allow_collisions id then (
- assert (not is_opaque);
+ (* Check with the ids which are considered to be strict on collisions *)
+ names_map_check_collision id_to_string id name nm.strict_names_map;
{
- ctx with
- unsafe_names_map = unsafe_names_map_add id name ctx.unsafe_names_map;
+ nm with
+ unsafe_names_map = unsafe_names_map_add id name nm.unsafe_names_map;
})
else
- (* The id_to_string function to print nice debugging messages if there are
- * collisions *)
- let id_to_string (id : id) : string = id_to_string id ctx in
- let names_map =
- names_map_add id_to_string is_opaque id name ctx.names_map
+ (* Remark: if we are strict on collisions:
+ - we add the id to the strict collisions map
+ - we check that the id doesn't collide with the unsafe map
+ TODO: we might not check that:
+ - a user defined function doesn't collide with an assumed function
+ - two trait decl items don't collide with each other
+ *)
+ let strict_names_map =
+ if strict_collisions id then
+ names_map_add id_to_string id name nm.strict_names_map
+ else nm.strict_names_map
in
- { ctx with names_map }
+ let names_map = names_map_add id_to_string id name nm.names_map in
+ { nm with strict_names_map; names_map }
+
+let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx =
+ let id_to_string (id : id) : string = id_to_string id ctx in
+ let names_maps = names_maps_add id_to_string id name ctx.names_maps in
+ { ctx with names_maps }
-(** [with_opaque_pre]: if [true] and the definition is opaque, add the opaque prefix *)
-let ctx_get (with_opaque_pre : bool) (id : id) (ctx : extraction_ctx) : string =
+(** The [id_to_string] function to print nice debugging messages if there are
+ collisions *)
+let names_maps_get (id_to_string : id -> string) (id : id) (nm : names_maps) :
+ string =
(* We do not use the same name map if we allow/disallow collisions *)
- if allow_collisions id then IdMap.find id ctx.unsafe_names_map.id_to_name
+ let map_to_string (m : string IdMap.t) : string =
+ "[\n"
+ ^ String.concat ","
+ (List.map
+ (fun (id, n) -> "\n " ^ id_to_string id ^ " -> " ^ n)
+ (IdMap.bindings m))
+ ^ "\n]"
+ in
+ if allow_collisions id then (
+ let m = nm.unsafe_names_map.id_to_name in
+ match IdMap.find_opt id m with
+ | Some s -> s
+ | None ->
+ let err =
+ "Could not find: " ^ id_to_string id ^ "\nNames map:\n"
+ ^ map_to_string m
+ in
+ log#serror err;
+ if !Config.fail_hard then raise (Failure err)
+ else "(%%%ERROR: unknown identifier\": " ^ id_to_string id ^ "\"%%%)")
else
- match IdMap.find_opt id ctx.names_map.id_to_name with
- | Some s ->
- let is_opaque = IdSet.mem id ctx.names_map.opaque_ids in
- if with_opaque_pre && is_opaque then ctx.fmt.opaque_pre () ^ s else s
+ let m = nm.names_map.id_to_name in
+ match IdMap.find_opt id m with
+ | Some s -> s
| None ->
- log#serror ("Could not find: " ^ id_to_string id ctx);
- raise Not_found
+ let err =
+ "Could not find: " ^ id_to_string id ^ "\nNames map:\n"
+ ^ map_to_string m
+ in
+ log#serror err;
+ if !Config.fail_hard then raise (Failure err)
+ else "(ERROR: \"" ^ id_to_string id ^ "\")"
+
+let ctx_get (id : id) (ctx : extraction_ctx) : string =
+ let id_to_string (id : id) : string = id_to_string id ctx in
+ names_maps_get id_to_string id ctx.names_maps
+
+let names_maps_add_assumed_type (id_to_string : id -> string) (id : assumed_ty)
+ (name : string) (nm : names_maps) : names_maps =
+ names_maps_add id_to_string (TypeId (Assumed id)) name nm
+
+let names_maps_add_assumed_struct (id_to_string : id -> string)
+ (id : assumed_ty) (name : string) (nm : names_maps) : names_maps =
+ names_maps_add id_to_string (StructId (Assumed id)) name nm
-let ctx_get_global (with_opaque_pre : bool) (id : A.GlobalDeclId.id)
+let names_maps_add_assumed_variant (id_to_string : id -> string)
+ (id : assumed_ty) (variant_id : VariantId.id) (name : string)
+ (nm : names_maps) : names_maps =
+ names_maps_add id_to_string (VariantId (Assumed id, variant_id)) name nm
+
+let names_maps_add_function (id_to_string : id -> string) (fid : fun_id)
+ (name : string) (nm : names_maps) : names_maps =
+ names_maps_add id_to_string (FunId fid) name nm
+
+let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string =
+ ctx_get (GlobalId id) ctx
+
+let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string =
+ ctx_get (FunId id) ctx
+
+let ctx_get_local_function (id : A.FunDeclId.id) (lp : LoopId.id option)
+ (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string =
+ ctx_get_function (FromLlbc (FunId (Regular id), lp, rg)) ctx
+
+let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string =
+ assert (id <> Tuple);
+ ctx_get (TypeId id) ctx
+
+let ctx_get_local_type (id : TypeDeclId.id) (ctx : extraction_ctx) : string =
+ ctx_get_type (AdtId id) ctx
+
+let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string =
+ ctx_get_type (Assumed id) ctx
+
+let ctx_get_trait_constructor (id : trait_decl_id) (ctx : extraction_ctx) :
+ string =
+ ctx_get (TraitDeclConstructorId id) ctx
+
+let ctx_get_trait_self_clause (ctx : extraction_ctx) : string =
+ ctx_get TraitSelfClauseId ctx
+
+let ctx_get_trait_decl (id : trait_decl_id) (ctx : extraction_ctx) : string =
+ ctx_get (TraitDeclId id) ctx
+
+let ctx_get_trait_impl (id : trait_impl_id) (ctx : extraction_ctx) : string =
+ ctx_get (TraitImplId id) ctx
+
+let ctx_get_trait_item (id : trait_decl_id) (item_name : string)
(ctx : extraction_ctx) : string =
- ctx_get with_opaque_pre (GlobalId id) ctx
+ ctx_get (TraitItemId (id, item_name)) ctx
-let ctx_get_function (with_opaque_pre : bool) (id : fun_id)
+let ctx_get_trait_const (id : trait_decl_id) (item_name : string)
(ctx : extraction_ctx) : string =
- ctx_get with_opaque_pre (FunId id) ctx
+ ctx_get_trait_item id item_name ctx
-let ctx_get_local_function (with_opaque_pre : bool) (id : A.FunDeclId.id)
- (lp : LoopId.id option) (rg : RegionGroupId.id option)
+let ctx_get_trait_type (id : trait_decl_id) (item_name : string)
(ctx : extraction_ctx) : string =
- ctx_get_function with_opaque_pre (FromLlbc (Regular id, lp, rg)) ctx
+ ctx_get_trait_item id item_name ctx
-let ctx_get_type (with_opaque_pre : bool) (id : type_id) (ctx : extraction_ctx)
- : string =
- assert (id <> Tuple);
- ctx_get with_opaque_pre (TypeId id) ctx
+let ctx_get_trait_method (id : trait_decl_id) (item_name : string)
+ (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string =
+ ctx_get (TraitMethodId (id, item_name, rg_id)) ctx
-let ctx_get_local_type (with_opaque_pre : bool) (id : TypeDeclId.id)
+let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id)
(ctx : extraction_ctx) : string =
- ctx_get_type with_opaque_pre (AdtId id) ctx
+ ctx_get (TraitParentClauseId (id, clause)) ctx
-let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string =
- (* In practice, the assumed types are opaque. However, assumed types
- are never grouped in the opaque module, meaning we never need to
- prefix them: we thus consider them as non-opaque with regards to the
- names map.
- *)
- let is_opaque = false in
- ctx_get_type is_opaque (Assumed id) ctx
+let ctx_get_trait_item_clause (id : trait_decl_id) (item : string)
+ (clause : trait_clause_id) (ctx : extraction_ctx) : string =
+ ctx_get (TraitItemClauseId (id, item, clause)) ctx
let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string =
- let is_opaque = false in
- ctx_get is_opaque (VarId id) ctx
+ ctx_get (VarId id) ctx
let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string =
- let is_opaque = false in
- ctx_get is_opaque (TypeVarId id) ctx
+ ctx_get (TypeVarId id) ctx
let ctx_get_const_generic_var (id : ConstGenericVarId.id) (ctx : extraction_ctx)
: string =
- let is_opaque = false in
- ctx_get is_opaque (ConstGenericVarId id) ctx
+ ctx_get (ConstGenericVarId id) ctx
+
+let ctx_get_local_trait_clause (id : TraitClauseId.id) (ctx : extraction_ctx) :
+ string =
+ ctx_get (LocalTraitClauseId id) ctx
let ctx_get_field (type_id : type_id) (field_id : FieldId.id)
(ctx : extraction_ctx) : string =
- let is_opaque = false in
- ctx_get is_opaque (FieldId (type_id, field_id)) ctx
+ ctx_get (FieldId (type_id, field_id)) ctx
-let ctx_get_struct (with_opaque_pre : bool) (def_id : type_id)
- (ctx : extraction_ctx) : string =
- ctx_get with_opaque_pre (StructId def_id) ctx
+let ctx_get_struct (def_id : type_id) (ctx : extraction_ctx) : string =
+ ctx_get (StructId def_id) ctx
let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id)
(ctx : extraction_ctx) : string =
- let is_opaque = false in
- ctx_get is_opaque (VariantId (def_id, variant_id)) ctx
+ ctx_get (VariantId (def_id, variant_id)) ctx
let ctx_get_decreases_proof (def_id : A.FunDeclId.id)
(loop_id : LoopId.id option) (ctx : extraction_ctx) : string =
- let is_opaque = false in
- ctx_get is_opaque (DecreasesProofId (Regular def_id, loop_id)) ctx
+ ctx_get (DecreasesProofId (Regular def_id, loop_id)) ctx
let ctx_get_termination_measure (def_id : A.FunDeclId.id)
(loop_id : LoopId.id option) (ctx : extraction_ctx) : string =
- let is_opaque = false in
- ctx_get is_opaque (TerminationMeasureId (Regular def_id, loop_id)) ctx
+ ctx_get (TerminationMeasureId (Regular def_id, loop_id)) ctx
(** Generate a unique type variable name and add it to the context *)
let ctx_add_type_var (basename : string) (id : TypeVarId.id)
(ctx : extraction_ctx) : extraction_ctx * string =
- let is_opaque = false in
- let name = ctx.fmt.type_var_basename ctx.names_map.names_set basename in
let name =
- basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name
+ ctx.fmt.type_var_basename ctx.names_maps.names_map.names_set basename
+ in
+ let name =
+ basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index
+ name
in
- let ctx = ctx_add is_opaque (TypeVarId id) name ctx in
+ let ctx = ctx_add (TypeVarId id) name ctx in
(ctx, name)
(** Generate a unique const generic variable name and add it to the context *)
let ctx_add_const_generic_var (basename : string) (id : ConstGenericVarId.id)
(ctx : extraction_ctx) : extraction_ctx * string =
- let is_opaque = false in
let name =
- ctx.fmt.const_generic_var_basename ctx.names_map.names_set basename
+ ctx.fmt.const_generic_var_basename ctx.names_maps.names_map.names_set
+ basename
in
let name =
- basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name
+ basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index
+ name
in
- let ctx = ctx_add is_opaque (ConstGenericVarId id) name ctx in
+ let ctx = ctx_add (ConstGenericVarId id) name ctx in
(ctx, name)
(** See {!ctx_add_type_var} *)
@@ -856,11 +1070,31 @@ let ctx_add_type_vars (vars : (string * TypeVarId.id) list)
(** Generate a unique variable name and add it to the context *)
let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) :
extraction_ctx * string =
- let is_opaque = false in
let name =
- basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename
+ basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index
+ basename
in
- let ctx = ctx_add is_opaque (VarId id) name ctx in
+ let ctx = ctx_add (VarId id) name ctx in
+ (ctx, name)
+
+(** Generate a unique variable name for the trait self clause and add it to the context *)
+let ctx_add_trait_self_clause (ctx : extraction_ctx) : extraction_ctx * string =
+ let basename = ctx.fmt.trait_self_clause_basename in
+ let name =
+ basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index
+ basename
+ in
+ let ctx = ctx_add TraitSelfClauseId name ctx in
+ (ctx, name)
+
+(** Generate a unique trait clause name and add it to the context *)
+let ctx_add_local_trait_clause (basename : string) (id : TraitClauseId.id)
+ (ctx : extraction_ctx) : extraction_ctx * string =
+ let name =
+ basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index
+ basename
+ in
+ let ctx = ctx_add (LocalTraitClauseId id) name ctx in
(ctx, name)
(** See {!ctx_add_var} *)
@@ -868,7 +1102,9 @@ let ctx_add_vars (vars : var list) (ctx : extraction_ctx) :
extraction_ctx * string list =
List.fold_left_map
(fun ctx (v : var) ->
- let name = ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty in
+ let name =
+ ctx.fmt.var_basename ctx.names_maps.names_map.names_set v.basename v.ty
+ in
ctx_add_var name v.id ctx)
ctx vars
@@ -885,142 +1121,105 @@ let ctx_add_const_generic_params (vars : const_generic_var list)
ctx_add_const_generic_var var.name var.index ctx)
ctx vars
-let ctx_add_type_const_generic_params (tvars : type_var list)
- (cgvars : const_generic_var list) (ctx : extraction_ctx) :
- extraction_ctx * string list * string list =
- let ctx, tys = ctx_add_type_params tvars ctx in
- let ctx, cgs = ctx_add_const_generic_params cgvars ctx in
- (ctx, tys, cgs)
-
-let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) :
- extraction_ctx * string =
- assert (match def.kind with Struct _ -> true | _ -> false);
- let is_opaque = false in
- let cons_name = ctx.fmt.struct_constructor def.name in
- let ctx = ctx_add is_opaque (StructId (AdtId def.def_id)) cons_name ctx in
- (ctx, cons_name)
-
-let ctx_add_type_decl (def : type_decl) (ctx : extraction_ctx) : extraction_ctx
- =
- let is_opaque = def.kind = Opaque in
- let def_name = ctx.fmt.type_name def.name in
- let ctx = ctx_add is_opaque (TypeId (AdtId def.def_id)) def_name ctx in
- ctx
-
-let ctx_add_field (def : type_decl) (field_id : FieldId.id) (field : field)
- (ctx : extraction_ctx) : extraction_ctx * string =
- let is_opaque = false in
- let name = ctx.fmt.field_name def.name field_id field.field_name in
- let ctx = ctx_add is_opaque (FieldId (AdtId def.def_id, field_id)) name ctx in
- (ctx, name)
-
-let ctx_add_fields (def : type_decl) (fields : (FieldId.id * field) list)
+let ctx_add_local_trait_clauses (clauses : trait_clause list)
(ctx : extraction_ctx) : extraction_ctx * string list =
List.fold_left_map
- (fun ctx (vid, v) -> ctx_add_field def vid v ctx)
- ctx fields
-
-let ctx_add_variant (def : type_decl) (variant_id : VariantId.id)
- (variant : variant) (ctx : extraction_ctx) : extraction_ctx * string =
- let is_opaque = false in
- let name = ctx.fmt.variant_name def.name variant.variant_name in
- (* Add the type name prefix for Lean *)
- let name =
- if !Config.backend = Lean then
- let type_name = ctx.fmt.type_name def.name in
- type_name ^ "." ^ name
- else name
- in
- let ctx =
- ctx_add is_opaque (VariantId (AdtId def.def_id, variant_id)) name ctx
- in
- (ctx, name)
-
-let ctx_add_variants (def : type_decl)
- (variants : (VariantId.id * variant) list) (ctx : extraction_ctx) :
- extraction_ctx * string list =
- List.fold_left_map
- (fun ctx (vid, v) -> ctx_add_variant def vid v ctx)
- ctx variants
+ (fun ctx (c : trait_clause) ->
+ let basename =
+ ctx.fmt.trait_clause_basename ctx.names_maps.names_map.names_set c
+ in
+ ctx_add_local_trait_clause basename c.clause_id ctx)
+ ctx clauses
-let ctx_add_struct (def : type_decl) (ctx : extraction_ctx) :
- extraction_ctx * string =
- assert (match def.kind with Struct _ -> true | _ -> false);
- let is_opaque = false in
- let name = ctx.fmt.struct_constructor def.name in
- let ctx = ctx_add is_opaque (StructId (AdtId def.def_id)) name ctx in
- (ctx, name)
+(** Returns the lists of names for:
+ - the type variables
+ - the const generic variables
+ - the trait clauses
+ *)
+let ctx_add_generic_params (generics : generic_params) (ctx : extraction_ctx) :
+ extraction_ctx * string list * string list * string list =
+ let { types; const_generics; trait_clauses } = generics in
+ let ctx, tys = ctx_add_type_params types ctx in
+ let ctx, cgs = ctx_add_const_generic_params const_generics ctx in
+ let ctx, tcs = ctx_add_local_trait_clauses trait_clauses ctx in
+ (ctx, tys, cgs, tcs)
let ctx_add_decreases_proof (def : fun_decl) (ctx : extraction_ctx) :
extraction_ctx =
- let is_opaque = false in
let name =
ctx.fmt.decreases_proof_name def.def_id def.basename def.num_loops
def.loop_id
in
- ctx_add is_opaque
- (DecreasesProofId (Regular def.def_id, def.loop_id))
- name ctx
+ ctx_add (DecreasesProofId (Regular def.def_id, def.loop_id)) name ctx
let ctx_add_termination_measure (def : fun_decl) (ctx : extraction_ctx) :
extraction_ctx =
- let is_opaque = false in
let name =
ctx.fmt.termination_measure_name def.def_id def.basename def.num_loops
def.loop_id
in
- ctx_add is_opaque
- (TerminationMeasureId (Regular def.def_id, def.loop_id))
- name ctx
+ ctx_add (TerminationMeasureId (Regular def.def_id, def.loop_id)) name ctx
let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
extraction_ctx =
(* TODO: update once the body id can be an option *)
- let is_opaque = false in
- let name = ctx.fmt.global_name def.name in
let decl = GlobalId def.def_id in
- let body = FunId (FromLlbc (Regular def.body_id, None, None)) in
- let ctx = ctx_add is_opaque decl (name ^ "_c") ctx in
- let ctx = ctx_add is_opaque body (name ^ "_body") ctx in
- ctx
-let ctx_add_fun_decl (trans_group : bool * pure_fun_translation)
- (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx =
- (* Sanity check: the function should not be a global body - those are handled
- * separately *)
- assert (not def.is_global_decl_body);
+ (* Check if the global corresponds to an assumed global that we should map
+ to a custom definition in our standard library (for instance, happens
+ with "core::num::usize::MAX") *)
+ let sname = name_to_simple_name def.name in
+ match SimpleNameMap.find_opt sname builtin_globals_map with
+ | Some name ->
+ (* Yes: register the custom binding *)
+ ctx_add decl name ctx
+ | None ->
+ (* Not the case: "standard" registration *)
+ let name = ctx.fmt.global_name def.name in
+ let body = FunId (FromLlbc (FunId (Regular def.body_id), None, None)) in
+ let ctx = ctx_add decl (name ^ "_c") ctx in
+ let ctx = ctx_add body (name ^ "_body") ctx in
+ ctx
+
+let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl)
+ (ctx : extraction_ctx) : string =
(* Lookup the LLBC def to compute the region group information *)
let def_id = def.def_id in
- let llbc_def =
- A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls
- in
+ let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in
let sg = llbc_def.signature in
let num_rgs = List.length sg.regions_hierarchy in
- let keep_fwd, (_, backs) = trans_group in
+ let { keep_fwd; fwd = _; backs } = trans_group in
let num_backs = List.length backs in
let rg_info =
match def.back_id with
| None -> None
| Some rg_id ->
let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in
- let regions =
+ let region_names =
List.map
- (fun rid -> T.RegionVarId.nth sg.region_params rid)
+ (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
rg.regions
in
- let region_names =
- List.map (fun (r : T.region_var) -> r.name) regions
- in
Some { id = rg_id; region_names }
in
- let is_opaque = def.body = None in
(* Add the function name *)
- let def_name =
- ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info
- (keep_fwd, num_backs)
- in
- let fun_id = (A.Regular def_id, def.loop_id, def.back_id) in
- let ctx = ctx_add is_opaque (FunId (FromLlbc fun_id)) def_name ctx in
+ ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info
+ (keep_fwd, num_backs)
+
+(* TODO: move to Extract *)
+let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl)
+ (ctx : extraction_ctx) : extraction_ctx =
+ (* Sanity check: the function should not be a global body - those are handled
+ * separately *)
+ assert (not def.is_global_decl_body);
+ (* Lookup the LLBC def to compute the region group information *)
+ let def_id = def.def_id in
+ let { keep_fwd; fwd = _; backs } = trans_group in
+ let num_backs = List.length backs in
+ (* Add the function name *)
+ let def_name = ctx_compute_fun_name trans_group def ctx in
+ let fun_id = (Pure.FunId (Regular def_id), def.loop_id, def.back_id) in
+ let ctx = ctx_add (FunId (FromLlbc fun_id)) def_name ctx in
(* Add the name info *)
{
ctx with
@@ -1039,9 +1238,10 @@ type names_map_init = {
assumed_pure_functions : (pure_assumed_fun_id * string) list;
}
-(** Initialize a names map with a proper set of keywords/names coming from the
+(** Initialize names maps with a proper set of keywords/names coming from the
target language/prover. *)
-let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map =
+let initialize_names_maps (fmt : formatter) (init : names_map_init) : names_maps
+ =
let int_names = List.map fmt.int_name T.all_int_types in
let keywords =
List.concat
@@ -1049,20 +1249,30 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map =
[ fmt.bool_name; fmt.char_name; fmt.str_name ]; int_names; init.keywords;
]
in
- let names_set = StringSet.of_list keywords in
- let name_to_id =
- StringMap.of_list (List.map (fun x -> (x, UnknownId)) keywords)
- in
- let opaque_ids = IdSet.empty in
+ let names_set = StringSet.empty in
+ let name_to_id = StringMap.empty in
(* We fist initialize [id_to_name] as empty, because the id of a keyword is [UnknownId].
* Also note that we don't need this mapping for keywords: we insert keywords only
* to check collisions. *)
let id_to_name = IdMap.empty in
- let nm = { id_to_name; name_to_id; names_set; opaque_ids } in
+ let names_map = { id_to_name; name_to_id; names_set } in
+ let unsafe_names_map = empty_unsafe_names_map in
+ let strict_names_map = empty_names_map in
(* For debugging - we are creating bindings for assumed types and functions, so
* it is ok if we simply use the "show" function (those aren't simply identified
* by numbers) *)
let id_to_string = show_id in
+ (* Add the keywords as strict collisions *)
+ let strict_names_map =
+ List.fold_left
+ (fun nm name ->
+ (* There is duplication in the keywords so we don't check the collisions
+ while registering them (what is important is that there are no collisions
+ between keywords and user-defined identifiers) *)
+ names_map_add_unchecked UnknownId name nm)
+ strict_names_map keywords
+ in
+ let nm = { names_map; unsafe_names_map; strict_names_map } in
(* Then we add:
* - the assumed types
* - the assumed struct constructors
@@ -1072,37 +1282,31 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map =
let nm =
List.fold_left
(fun nm (type_id, name) ->
- names_map_add_assumed_type id_to_string type_id name nm)
+ names_maps_add_assumed_type id_to_string type_id name nm)
nm init.assumed_adts
in
let nm =
List.fold_left
(fun nm (type_id, name) ->
- names_map_add_assumed_struct id_to_string type_id name nm)
+ names_maps_add_assumed_struct id_to_string type_id name nm)
nm init.assumed_structs
in
let nm =
List.fold_left
(fun nm (type_id, variant_id, name) ->
- names_map_add_assumed_variant id_to_string type_id variant_id name nm)
+ names_maps_add_assumed_variant id_to_string type_id variant_id name nm)
nm init.assumed_variants
in
let assumed_functions =
List.map
- (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, None, rg), name))
+ (fun (fid, rg, name) ->
+ (FromLlbc (Pure.FunId (Assumed fid), None, rg), name))
init.assumed_llbc_functions
@ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions
in
let nm =
- (* In practice, the assumed function are opaque. However, assumed functions
- are never grouped in the opaque module, meaning we never need to
- prefix them: we thus consider them as non-opaque with regards to the
- names map.
- *)
- let is_opaque = false in
List.fold_left
- (fun nm (fid, name) ->
- names_map_add_function id_to_string is_opaque fid name nm)
+ (fun nm (fid, name) -> names_maps_add_function id_to_string fid name nm)
nm assumed_functions
in
(* Return *)
@@ -1150,22 +1354,20 @@ let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option)
let rg_suff =
(* TODO: make all the backends match what is done for Lean *)
match rg with
- | None -> (
- match !Config.backend with
- | FStar | Coq | HOL4 -> "_fwd"
- | Lean ->
- (* In order to avoid name conflicts:
- * - if the forward is eliminated, we add the suffix "_fwd" (it won't be used)
- * - otherwise, no suffix (because the backward functions will have a suffix)
- *)
- if num_backs = 1 && not keep_fwd then "_fwd" else "")
+ | None ->
+ if
+ (* In order to avoid name conflicts:
+ * - if the forward is eliminated, we add the suffix "_fwd" (it won't be used)
+ * - otherwise, no suffix (because the backward functions will have a suffix)
+ *)
+ num_backs = 1 && not keep_fwd
+ then "_fwd"
+ else ""
| Some rg ->
assert (num_region_groups > 0 && num_backs > 0);
if num_backs = 1 then
(* Exactly one backward function *)
- match !Config.backend with
- | FStar | Coq | HOL4 -> if not keep_fwd then "_fwd_back" else "_back"
- | Lean -> if not keep_fwd then "" else "_back"
+ if not keep_fwd then "" else "_back"
else if
(* Several region groups/backward functions:
- if all the regions in the group have names, we use those names
diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml
new file mode 100644
index 00000000..a54ab604
--- /dev/null
+++ b/compiler/ExtractBuiltin.ml
@@ -0,0 +1,648 @@
+(** This file declares external identifiers that we catch to map them to
+ definitions coming from the standard libraries in our backends.
+
+ TODO: there misses trait **implementations**
+ *)
+
+open Names
+open Config
+
+type simple_name = string list [@@deriving show, ord]
+
+let name_to_simple_name (s : name) : simple_name =
+ (* We simply ignore the disambiguators *)
+ List.filter_map (function Ident id -> Some id | Disambiguator _ -> None) s
+
+(** Small helper which cuts a string at the occurrences of "::" *)
+let string_to_simple_name (s : string) : simple_name =
+ (* No function to split by using string separator?? *)
+ let name = String.split_on_char ':' s in
+ List.filter (fun s -> s <> "") name
+
+module SimpleNameOrd = struct
+ type t = simple_name
+
+ let compare = compare_simple_name
+ let to_string = show_simple_name
+ let pp_t = pp_simple_name
+ let show_t = show_simple_name
+end
+
+module SimpleNameMap = Collections.MakeMap (SimpleNameOrd)
+module SimpleNameSet = Collections.MakeSet (SimpleNameOrd)
+
+(** Small utility to memoize some computations *)
+let mk_memoized (f : unit -> 'a) : unit -> 'a =
+ let r = ref None in
+ let g () =
+ match !r with
+ | Some x -> x
+ | None ->
+ let x = f () in
+ r := Some x;
+ x
+ in
+ g
+
+(** Switch between two values depending on the target backend.
+
+ We often compute the same value (typically: a name) if the target
+ is F*, Coq or HOL4, and a different value if the target is Lean.
+ *)
+let backend_choice (fstar_coq_hol4 : 'a) (lean : 'a) : 'a =
+ match !backend with Coq | FStar | HOL4 -> fstar_coq_hol4 | Lean -> lean
+
+let builtin_globals : (string * string) list =
+ [
+ (* Min *)
+ ("core::num::usize::MIN", "core_usize_min");
+ ("core::num::u8::MIN", "core_u8_min");
+ ("core::num::u16::MIN", "core_u16_min");
+ ("core::num::u32::MIN", "core_u32_min");
+ ("core::num::u64::MIN", "core_u64_min");
+ ("core::num::u128::MIN", "core_u128_min");
+ ("core::num::isize::MIN", "core_isize_min");
+ ("core::num::i8::MIN", "core_i8_min");
+ ("core::num::i16::MIN", "core_i16_min");
+ ("core::num::i32::MIN", "core_i32_min");
+ ("core::num::i64::MIN", "core_i64_min");
+ ("core::num::i128::MIN", "core_i128_min");
+ (* Max *)
+ ("core::num::usize::MAX", "core_usize_max");
+ ("core::num::u8::MAX", "core_u8_max");
+ ("core::num::u16::MAX", "core_u16_max");
+ ("core::num::u32::MAX", "core_u32_max");
+ ("core::num::u64::MAX", "core_u64_max");
+ ("core::num::u128::MAX", "core_u128_max");
+ ("core::num::isize::MAX", "core_isize_max");
+ ("core::num::i8::MAX", "core_i8_max");
+ ("core::num::i16::MAX", "core_i16_max");
+ ("core::num::i32::MAX", "core_i32_max");
+ ("core::num::i64::MAX", "core_i64_max");
+ ("core::num::i128::MAX", "core_i128_max");
+ ]
+
+let builtin_globals_map : string SimpleNameMap.t =
+ SimpleNameMap.of_list
+ (List.map (fun (x, y) -> (string_to_simple_name x, y)) builtin_globals)
+
+type builtin_variant_info = { fields : (string * string) list }
+[@@deriving show]
+
+type builtin_enum_variant_info = {
+ rust_variant_name : string;
+ extract_variant_name : string;
+ fields : string list option;
+}
+[@@deriving show]
+
+type builtin_type_body_info =
+ | Struct of string * (string * string) list
+ (* The constructor name and the map for the field names *)
+ | Enum of builtin_enum_variant_info list
+(* For every variant, a map for the field names *)
+[@@deriving show]
+
+type builtin_type_info = {
+ rust_name : string list;
+ extract_name : string;
+ keep_params : bool list option;
+ (** We might want to filter some of the type parameters.
+
+ For instance, `Vec` type takes a type parameter for the allocator,
+ which we want to ignore.
+ *)
+ body_info : builtin_type_body_info option;
+}
+[@@deriving show]
+
+type type_variant_kind =
+ | KOpaque
+ | KStruct of (string * string) list
+ (* TODO: handle the tuple case *)
+ | KEnum (* TODO *)
+
+let mk_struct_constructor (type_name : string) : string =
+ let prefix =
+ match !backend with FStar -> "Mk" | Coq | HOL4 -> "mk" | Lean -> ""
+ in
+ let suffix = match !backend with FStar | Coq | HOL4 -> "" | Lean -> ".mk" in
+ prefix ^ type_name ^ suffix
+
+(** The assumed types.
+
+ The optional list of booleans is filtering information for the type
+ parameters. For instance, in the case of the `Vec` functions, there is
+ a type parameter for the allocator to use, which we want to filter.
+ *)
+let builtin_types () : builtin_type_info list =
+ let mk_type (rust_name : string list) ?(keep_params : bool list option = None)
+ ?(kind : type_variant_kind = KOpaque) () : builtin_type_info =
+ let extract_name =
+ let sep = backend_choice "_" "." in
+ String.concat sep rust_name
+ in
+ let body_info : builtin_type_body_info option =
+ match kind with
+ | KOpaque -> None
+ | KStruct fields ->
+ let fields =
+ List.map
+ (fun (rname, name) ->
+ ( rname,
+ match !backend with
+ | FStar | Lean -> name
+ | Coq | HOL4 -> extract_name ^ "_" ^ name ))
+ fields
+ in
+ let constructor = mk_struct_constructor extract_name in
+ Some (Struct (constructor, fields))
+ | KEnum -> raise (Failure "TODO")
+ in
+ { rust_name; extract_name; keep_params; body_info }
+ in
+
+ [
+ (* Alloc *)
+ mk_type [ "alloc"; "alloc"; "Global" ] ();
+ (* Vec *)
+ mk_type [ "alloc"; "vec"; "Vec" ] ~keep_params:(Some [ true; false ]) ();
+ (* Range *)
+ mk_type
+ [ "core"; "ops"; "range"; "Range" ]
+ ~kind:(KStruct [ ("start", "start"); ("end", "end_") ])
+ ();
+ (* Option
+
+ This one is more custom because we use the standard "option" type from
+ the target backend.
+ *)
+ {
+ rust_name = [ "core"; "option"; "Option" ];
+ extract_name =
+ (match !backend with
+ | Lean -> "Option"
+ | Coq | FStar | HOL4 -> "option");
+ keep_params = None;
+ body_info =
+ Some
+ (Enum
+ [
+ {
+ rust_variant_name = "None";
+ extract_variant_name =
+ (match !backend with
+ | FStar | Coq -> "None"
+ | Lean -> "none"
+ | HOL4 -> "NONE");
+ fields = None;
+ };
+ {
+ rust_variant_name = "Some";
+ extract_variant_name =
+ (match !backend with
+ | FStar | Coq -> "Some"
+ | Lean -> "some"
+ | HOL4 -> "SOME");
+ fields = None;
+ };
+ ]);
+ };
+ ]
+
+let mk_builtin_types_map () =
+ SimpleNameMap.of_list
+ (List.map (fun info -> (info.rust_name, info)) (builtin_types ()))
+
+let builtin_types_map = mk_memoized mk_builtin_types_map
+
+type builtin_fun_info = {
+ rg : Types.RegionGroupId.id option;
+ extract_name : string;
+}
+[@@deriving show]
+
+(** The assumed functions.
+
+ The optional list of booleans is filtering information for the type
+ parameters. For instance, in the case of the `Vec` functions, there is
+ a type parameter for the allocator to use, which we want to filter.
+ *)
+let builtin_funs () :
+ (string list * bool list option * builtin_fun_info list) list =
+ let rg0 = Some Types.RegionGroupId.zero in
+ (* Small utility *)
+ let mk_fun (name : string list) (extract_name : string list option)
+ (filter : bool list option) (with_back : bool) (back_no_suffix : bool) :
+ string list * bool list option * builtin_fun_info list =
+ let extract_name =
+ match extract_name with None -> name | Some name -> name
+ in
+ let basename =
+ match !backend with
+ | FStar | Coq | HOL4 -> String.concat "_" extract_name
+ | Lean -> String.concat "." extract_name
+ in
+ let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in
+ let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in
+ let back_suffix = if with_back && back_no_suffix then "" else "_back" in
+ let back =
+ if with_back then [ { rg = rg0; extract_name = basename ^ back_suffix } ]
+ else []
+ in
+ (name, filter, fwd @ back)
+ in
+ [
+ mk_fun [ "core"; "mem"; "replace" ] None None true false;
+ mk_fun [ "alloc"; "vec"; "Vec"; "new" ] None None false false;
+ mk_fun
+ [ "alloc"; "vec"; "Vec"; "push" ]
+ None
+ (Some [ true; false ])
+ true true;
+ mk_fun
+ [ "alloc"; "vec"; "Vec"; "insert" ]
+ None
+ (Some [ true; false ])
+ true true;
+ mk_fun
+ [ "alloc"; "vec"; "Vec"; "len" ]
+ None
+ (Some [ true; false ])
+ true false;
+ mk_fun
+ [ "alloc"; "vec"; "Vec"; "index" ]
+ None
+ (Some [ true; true; false ])
+ true false;
+ mk_fun
+ [ "alloc"; "vec"; "Vec"; "index_mut" ]
+ None
+ (Some [ true; true; false ])
+ true false;
+ mk_fun
+ [ "alloc"; "boxed"; "Box"; "deref" ]
+ None
+ (Some [ true; false ])
+ true false;
+ mk_fun
+ [ "alloc"; "boxed"; "Box"; "deref_mut" ]
+ None
+ (Some [ true; false ])
+ true false;
+ (* TODO: fix the same like "[T]" below *)
+ mk_fun
+ [ "core"; "slice"; "index"; "[T]"; "index" ]
+ (Some [ "core"; "slice"; "index"; "Slice"; "index" ])
+ None true false;
+ mk_fun
+ [ "core"; "slice"; "index"; "[T]"; "index_mut" ]
+ (Some [ "core"; "slice"; "index"; "Slice"; "index_mut" ])
+ None true false;
+ mk_fun
+ [ "core"; "array"; "[T; N]"; "index" ]
+ (Some [ "core"; "array"; "Array"; "index" ])
+ None true false;
+ mk_fun
+ [ "core"; "array"; "[T; N]"; "index_mut" ]
+ (Some [ "core"; "array"; "Array"; "index_mut" ])
+ None true false;
+ mk_fun [ "core"; "slice"; "index"; "Range"; "get" ] None None true false;
+ mk_fun [ "core"; "slice"; "index"; "Range"; "get_mut" ] None None true false;
+ mk_fun [ "core"; "slice"; "index"; "Range"; "index" ] None None true false;
+ mk_fun
+ [ "core"; "slice"; "index"; "Range"; "index_mut" ]
+ None None true false;
+ mk_fun
+ [ "core"; "slice"; "index"; "Range"; "get_unchecked" ]
+ None None false false;
+ mk_fun
+ [ "core"; "slice"; "index"; "Range"; "get_unchecked_mut" ]
+ None None false false;
+ mk_fun
+ [ "core"; "slice"; "index"; "usize"; "get" ]
+ (Some [ "core"; "slice"; "index"; "Usize"; "get" ])
+ None true false;
+ mk_fun
+ [ "core"; "slice"; "index"; "usize"; "get_mut" ]
+ (Some [ "core"; "slice"; "index"; "Usize"; "get_mut" ])
+ None true false;
+ mk_fun
+ [ "core"; "slice"; "index"; "usize"; "get_unchecked" ]
+ (Some [ "core"; "slice"; "index"; "Usize"; "get_unchecked" ])
+ None false false;
+ mk_fun
+ [ "core"; "slice"; "index"; "usize"; "get_unchecked_mut" ]
+ (Some [ "core"; "slice"; "index"; "Usize"; "get_unchecked_mut" ])
+ None false false;
+ mk_fun
+ [ "core"; "slice"; "index"; "usize"; "index" ]
+ (Some [ "core"; "slice"; "index"; "Usize"; "index" ])
+ None true false;
+ mk_fun
+ [ "core"; "slice"; "index"; "usize"; "index_mut" ]
+ (Some [ "core"; "slice"; "index"; "Usize"; "index_mut" ])
+ None true false;
+ ]
+
+let mk_builtin_funs_map () =
+ SimpleNameMap.of_list
+ (List.map
+ (fun (name, filter, info) -> (name, (filter, info)))
+ (builtin_funs ()))
+
+let builtin_funs_map = mk_memoized mk_builtin_funs_map
+
+type effect_info = { can_fail : bool; stateful : bool }
+
+let builtin_fun_effects =
+ let int_names =
+ [
+ "usize";
+ "u8";
+ "u16";
+ "u32";
+ "u64";
+ "u128";
+ "isize";
+ "i8";
+ "i16";
+ "i32";
+ "i64";
+ "i128";
+ ]
+ in
+ let int_ops =
+ [ "wrapping_add"; "wrapping_sub"; "rotate_left"; "rotate_right" ]
+ in
+ let int_funs =
+ List.map
+ (fun int_name ->
+ List.map (fun op -> "core::num::" ^ int_name ^ "::" ^ op) int_ops)
+ int_names
+ in
+ let int_funs = List.concat int_funs in
+ let no_fail_no_state_funs =
+ [
+ (* TODO: redundancy with the funs information below *)
+ "alloc::vec::Vec::new";
+ "alloc::vec::Vec::len";
+ "alloc::boxed::Box::deref";
+ "alloc::boxed::Box::deref_mut";
+ "core::mem::replace";
+ "core::mem::take";
+ ]
+ @ int_funs
+ in
+ let no_fail_no_state_funs =
+ List.map
+ (fun n -> (n, { can_fail = false; stateful = false }))
+ no_fail_no_state_funs
+ in
+ let no_state_funs =
+ [
+ (* TODO: redundancy with the funs information below *)
+ "alloc::vec::Vec::push";
+ "alloc::vec::Vec::index";
+ "alloc::vec::Vec::index_mut";
+ "alloc::vec::Vec::index_mut_back";
+ ]
+ in
+ let no_state_funs =
+ List.map (fun n -> (n, { can_fail = true; stateful = false })) no_state_funs
+ in
+ no_fail_no_state_funs @ no_state_funs
+
+let builtin_fun_effects_map =
+ SimpleNameMap.of_list
+ (List.map (fun (n, x) -> (string_to_simple_name n, x)) builtin_fun_effects)
+
+type builtin_trait_decl_info = {
+ rust_name : string;
+ extract_name : string;
+ constructor : string;
+ parent_clauses : string list;
+ consts : (string * string) list;
+ types : (string * (string * string list)) list;
+ (** Every type has:
+ - a Rust name
+ - an extraction name
+ - a list of clauses *)
+ methods : (string * builtin_fun_info list) list;
+}
+[@@deriving show]
+
+let builtin_trait_decls_info () =
+ let rg0 = Some Types.RegionGroupId.zero in
+ let mk_trait (rust_name : string list) ?(extract_name : string option = None)
+ ?(parent_clauses : string list = []) ?(types : string list = [])
+ ?(methods : (string * bool) list = []) () : builtin_trait_decl_info =
+ let extract_name =
+ match extract_name with
+ | Some n -> n
+ | None -> (
+ match !backend with
+ | Coq | FStar | HOL4 -> String.concat "_" rust_name
+ | Lean -> String.concat "." rust_name)
+ in
+ let constructor = mk_struct_constructor extract_name in
+ let consts = [] in
+ let types =
+ let mk_type item_name =
+ let type_name =
+ match !backend with
+ | Coq | FStar | HOL4 -> extract_name ^ "_" ^ item_name
+ | Lean -> item_name
+ in
+ let clauses = [] in
+ (item_name, (type_name, clauses))
+ in
+ List.map mk_type types
+ in
+ let methods =
+ let mk_method (item_name, with_back) =
+ (* TODO: factor out with builtin_funs_info *)
+ let basename =
+ match !backend with
+ | Coq | FStar | HOL4 -> extract_name ^ "_" ^ item_name
+ | Lean -> item_name
+ in
+ let back_no_suffix = false in
+ let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in
+ let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in
+ let back_suffix = if with_back && back_no_suffix then "" else "_back" in
+ let back =
+ if with_back then
+ [ { rg = rg0; extract_name = basename ^ back_suffix } ]
+ else []
+ in
+ (item_name, fwd @ back)
+ in
+ List.map mk_method methods
+ in
+ let rust_name = String.concat "::" rust_name in
+ {
+ rust_name;
+ extract_name;
+ constructor;
+ parent_clauses;
+ consts;
+ types;
+ methods;
+ }
+ in
+ [
+ (* Deref *)
+ mk_trait
+ [ "core"; "ops"; "deref"; "Deref" ]
+ ~types:[ "Target" ]
+ ~methods:[ ("deref", true) ]
+ ();
+ (* DerefMut *)
+ mk_trait
+ [ "core"; "ops"; "deref"; "DerefMut" ]
+ ~parent_clauses:[ backend_choice "deref_inst" "derefInst" ]
+ ~methods:[ ("deref_mut", true) ]
+ ();
+ (* Index *)
+ mk_trait
+ [ "core"; "ops"; "index"; "Index" ]
+ ~types:[ "Output" ]
+ ~methods:[ ("index", true) ]
+ ();
+ (* IndexMut *)
+ mk_trait
+ [ "core"; "ops"; "index"; "IndexMut" ]
+ ~parent_clauses:[ backend_choice "index_inst" "indexInst" ]
+ ~methods:[ ("index_mut", true) ]
+ ();
+ (* Sealed *)
+ mk_trait [ "core"; "slice"; "index"; "private_slice_index"; "Sealed" ] ();
+ (* SliceIndex *)
+ mk_trait
+ [ "core"; "slice"; "index"; "SliceIndex" ]
+ ~parent_clauses:[ backend_choice "sealed_inst" "sealedInst" ]
+ ~types:[ "Output" ]
+ ~methods:
+ [
+ ("get", true);
+ ("get_mut", true);
+ ("get_unchecked", false);
+ ("get_unchecked_mut", false);
+ ("index", true);
+ ("index_mut", true);
+ ]
+ ();
+ ]
+
+let mk_builtin_trait_decls_map () =
+ SimpleNameMap.of_list
+ (List.map
+ (fun info -> (string_to_simple_name info.rust_name, info))
+ (builtin_trait_decls_info ()))
+
+let builtin_trait_decls_map = mk_memoized mk_builtin_trait_decls_map
+
+(* TODO: generalize this.
+
+ For now, the key is:
+ - name of the impl (ex.: "alloc.boxed.Boxed")
+ - name of the implemented trait (ex.: "core.ops.deref.Deref"
+*)
+type simple_name_pair = simple_name * simple_name [@@deriving show, ord]
+
+module SimpleNamePairOrd = struct
+ type t = simple_name_pair
+
+ let compare = compare_simple_name_pair
+ let to_string = show_simple_name_pair
+ let pp_t = pp_simple_name_pair
+ let show_t = show_simple_name_pair
+end
+
+module SimpleNamePairMap = Collections.MakeMap (SimpleNamePairOrd)
+
+let builtin_trait_impls_info () :
+ ((string list * string list) * (bool list option * string)) list =
+ let fmt (type_name : string list)
+ ?(extract_type_name : string list option = None)
+ (trait_name : string list) ?(filter : bool list option = None) () :
+ (string list * string list) * (bool list option * string) =
+ let name =
+ let trait_name = String.concat "" trait_name ^ "Inst" in
+ let sep = backend_choice "_" "." in
+ let type_name =
+ match extract_type_name with
+ | Some type_name -> type_name
+ | None -> type_name
+ in
+ String.concat sep type_name ^ sep ^ trait_name
+ in
+ ((type_name, trait_name), (filter, name))
+ in
+ (* TODO: fix the names like "[T]" below *)
+ [
+ (* core::ops::Deref<alloc::boxed::Box<T>> *)
+ fmt [ "alloc"; "boxed"; "Box" ] [ "core"; "ops"; "deref"; "Deref" ] ();
+ (* core::ops::DerefMut<alloc::boxed::Box<T>> *)
+ fmt [ "alloc"; "boxed"; "Box" ] [ "core"; "ops"; "deref"; "DerefMut" ] ();
+ (* core::ops::index::Index<[T], I> *)
+ fmt
+ [ "core"; "slice"; "index"; "[T]" ]
+ ~extract_type_name:(Some [ "core"; "slice"; "index"; "Slice" ])
+ [ "core"; "ops"; "index"; "Index" ]
+ ();
+ (* core::ops::index::IndexMut<[T], I> *)
+ fmt
+ [ "core"; "slice"; "index"; "[T]" ]
+ ~extract_type_name:(Some [ "core"; "slice"; "index"; "Slice" ])
+ [ "core"; "ops"; "index"; "IndexMut" ]
+ ();
+ (* core::slice::index::private_slice_index::Sealed<Range<usize>> *)
+ fmt
+ [ "core"; "slice"; "index"; "private_slice_index"; "Range" ]
+ [ "core"; "slice"; "index"; "private_slice_index"; "Sealed" ]
+ ();
+ (* core::slice::index::SliceIndex<Range<usize>, [T]> *)
+ fmt
+ [ "core"; "slice"; "index"; "Range" ]
+ [ "core"; "slice"; "index"; "SliceIndex" ]
+ ();
+ (* core::ops::index::Index<[T; N], I> *)
+ fmt
+ [ "core"; "array"; "[T; N]" ]
+ ~extract_type_name:(Some [ "core"; "array"; "Array" ])
+ [ "core"; "ops"; "index"; "Index" ]
+ ();
+ (* core::ops::index::IndexMut<[T; N], I> *)
+ fmt
+ [ "core"; "array"; "[T; N]" ]
+ ~extract_type_name:(Some [ "core"; "array"; "Array" ])
+ [ "core"; "ops"; "index"; "IndexMut" ]
+ ();
+ (* core::slice::index::private_slice_index::Sealed<usize> *)
+ fmt
+ [ "core"; "slice"; "index"; "private_slice_index"; "usize" ]
+ [ "core"; "slice"; "index"; "private_slice_index"; "Sealed" ]
+ ();
+ (* core::slice::index::SliceIndex<usize, [T]> *)
+ fmt
+ [ "core"; "slice"; "index"; "usize" ]
+ [ "core"; "slice"; "index"; "SliceIndex" ]
+ ();
+ (* core::ops::index::Index<Vec<T>, T> *)
+ fmt [ "alloc"; "vec"; "Vec" ]
+ [ "core"; "ops"; "index"; "Index" ]
+ ~filter:(Some [ true; true; false ])
+ ();
+ (* core::ops::index::IndexMut<Vec<T>, T> *)
+ fmt [ "alloc"; "vec"; "Vec" ]
+ [ "core"; "ops"; "index"; "IndexMut" ]
+ ~filter:(Some [ true; true; false ])
+ ();
+ ]
+
+let mk_builtin_trait_impls_map () =
+ SimpleNamePairMap.of_list (builtin_trait_impls_info ())
+
+let builtin_trait_impls_map = mk_memoized mk_builtin_trait_impls_map
diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml
new file mode 100644
index 00000000..77f76bb4
--- /dev/null
+++ b/compiler/ExtractTypes.ml
@@ -0,0 +1,2477 @@
+(** The generic extraction *)
+(* Turn the whole module into a functor: it is very annoying to carry the
+ the formatter everywhere...
+*)
+
+open Pure
+open PureUtils
+open TranslateCore
+open ExtractBase
+open StringUtils
+open Config
+module F = Format
+
+(** Small helper to compute the name of an int type *)
+let int_name (int_ty : integer_type) =
+ let isize, usize, i_format, u_format =
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ ("isize", "usize", format_of_string "i%d", format_of_string "u%d")
+ | Lean -> ("Isize", "Usize", format_of_string "I%d", format_of_string "U%d")
+ in
+ match int_ty with
+ | Isize -> isize
+ | I8 -> Printf.sprintf i_format 8
+ | I16 -> Printf.sprintf i_format 16
+ | I32 -> Printf.sprintf i_format 32
+ | I64 -> Printf.sprintf i_format 64
+ | I128 -> Printf.sprintf i_format 128
+ | Usize -> usize
+ | U8 -> Printf.sprintf u_format 8
+ | U16 -> Printf.sprintf u_format 16
+ | U32 -> Printf.sprintf u_format 32
+ | U64 -> Printf.sprintf u_format 64
+ | U128 -> Printf.sprintf u_format 128
+
+(** Small helper to compute the name of a unary operation *)
+let unop_name (unop : unop) : string =
+ match unop with
+ | Not -> (
+ match !backend with FStar | Lean -> "not" | Coq -> "negb" | HOL4 -> "~")
+ | Neg (int_ty : integer_type) -> (
+ match !backend with Lean -> "-" | _ -> int_name int_ty ^ "_neg")
+ | Cast _ ->
+ (* We never directly use the unop name in this case *)
+ raise (Failure "Unsupported")
+
+(** Small helper to compute the name of a binary operation (note that many
+ binary operations like "less than" are extracted to primitive operations,
+ like [<]).
+ *)
+let named_binop_name (binop : E.binop) (int_ty : integer_type) : string =
+ let binop =
+ match binop with
+ | Div -> "div"
+ | Rem -> "rem"
+ | Add -> "add"
+ | Sub -> "sub"
+ | Mul -> "mul"
+ | Lt -> "lt"
+ | Le -> "le"
+ | Ge -> "ge"
+ | Gt -> "gt"
+ | BitXor -> "xor"
+ | BitAnd -> "and"
+ | BitOr -> "or"
+ | Shl -> "lsl"
+ | Shr ->
+ "asr"
+ (* NOTE: make sure arithmetic shift right is implemented, i.e. OCaml's asr operator, not lsr *)
+ | _ -> raise (Failure "Unreachable")
+ in
+ (* Remark: the Lean case is actually not used *)
+ match !backend with
+ | Lean -> int_name int_ty ^ "." ^ binop
+ | FStar | Coq | HOL4 -> int_name int_ty ^ "_" ^ binop
+
+(** A list of keywords/identifiers used by the backend and with which we
+ want to check collision.
+
+ Remark: this is useful mostly to look for collisions when generating
+ names for *variables*.
+ *)
+let keywords () =
+ let named_unops =
+ unop_name Not
+ :: List.map (fun it -> unop_name (Neg it)) T.all_signed_int_types
+ in
+ let named_binops = [ E.Div; Rem; Add; Sub; Mul ] in
+ let named_binops =
+ List.concat_map
+ (fun bn -> List.map (fun it -> named_binop_name bn it) T.all_int_types)
+ named_binops
+ in
+ let misc =
+ match !backend with
+ | FStar ->
+ [
+ "assert";
+ "assert_norm";
+ "assume";
+ "else";
+ "fun";
+ "fn";
+ "FStar";
+ "FStar.Mul";
+ "if";
+ "in";
+ "include";
+ "int";
+ "let";
+ "list";
+ "match";
+ "open";
+ "rec";
+ "scalar_cast";
+ "then";
+ "type";
+ "Type0";
+ "Type";
+ "unit";
+ "val";
+ "with";
+ ]
+ | Coq ->
+ [
+ "assert";
+ "Arguments";
+ "Axiom";
+ "char_of_byte";
+ "Check";
+ "Declare";
+ "Definition";
+ "else";
+ "End";
+ "fun";
+ "Fixpoint";
+ "if";
+ "in";
+ "int";
+ "Inductive";
+ "Import";
+ "let";
+ "Lemma";
+ "match";
+ "Module";
+ "not";
+ "Notation";
+ "Proof";
+ "Qed";
+ "rec";
+ "Record";
+ "Require";
+ "Scope";
+ "Search";
+ "SearchPattern";
+ "Set";
+ "then";
+ (* [tt] is unit *)
+ "tt";
+ "type";
+ "Type";
+ "unit";
+ "with";
+ ]
+ | Lean ->
+ [
+ "by";
+ "class";
+ "decreasing_by";
+ "def";
+ "deriving";
+ "do";
+ "else";
+ "end";
+ "for";
+ "have";
+ "if";
+ "inductive";
+ "instance";
+ "import";
+ "let";
+ "macro";
+ "match";
+ "namespace";
+ "opaque";
+ "open";
+ "run_cmd";
+ "set_option";
+ "simp";
+ "structure";
+ "syntax";
+ "termination_by";
+ "then";
+ "Type";
+ "unsafe";
+ "where";
+ "with";
+ "opaque_defs";
+ ]
+ | HOL4 ->
+ [
+ "Axiom";
+ "case";
+ "Definition";
+ "else";
+ "End";
+ "fix";
+ "fix_exec";
+ "fn";
+ "fun";
+ "if";
+ "in";
+ "int";
+ "Inductive";
+ "let";
+ "of";
+ "Proof";
+ "QED";
+ "then";
+ "Theorem";
+ ]
+ in
+ List.concat [ named_unops; named_binops; misc ]
+
+let assumed_adts () : (assumed_ty * string) list =
+ match !backend with
+ | Lean ->
+ [
+ (State, "State");
+ (Result, "Result");
+ (Error, "Error");
+ (Fuel, "Nat");
+ (Array, "Array");
+ (Slice, "Slice");
+ (Str, "Str");
+ (RawPtr Mut, "MutRawPtr");
+ (RawPtr Const, "ConstRawPtr");
+ ]
+ | Coq | FStar | HOL4 ->
+ [
+ (State, "state");
+ (Result, "result");
+ (Error, "error");
+ (Fuel, if !backend = HOL4 then "num" else "nat");
+ (Array, "array");
+ (Slice, "slice");
+ (Str, "str");
+ (RawPtr Mut, "mut_raw_ptr");
+ (RawPtr Const, "const_raw_ptr");
+ ]
+
+let assumed_struct_constructors () : (assumed_ty * string) list =
+ match !backend with
+ | Lean -> [ (Array, "Array.make") ]
+ | Coq -> [ (Array, "mk_array") ]
+ | FStar -> [ (Array, "mk_array") ]
+ | HOL4 -> [ (Array, "mk_array") ]
+
+let assumed_variants () : (assumed_ty * VariantId.id * string) list =
+ match !backend with
+ | FStar ->
+ [
+ (Result, result_return_id, "Return");
+ (Result, result_fail_id, "Fail");
+ (Error, error_failure_id, "Failure");
+ (Error, error_out_of_fuel_id, "OutOfFuel");
+ (* No Fuel::Zero on purpose *)
+ (* No Fuel::Succ on purpose *)
+ ]
+ | Coq ->
+ [
+ (Result, result_return_id, "Return");
+ (Result, result_fail_id, "Fail_");
+ (Error, error_failure_id, "Failure");
+ (Error, error_out_of_fuel_id, "OutOfFuel");
+ (Fuel, fuel_zero_id, "O");
+ (Fuel, fuel_succ_id, "S");
+ ]
+ | Lean ->
+ [
+ (Result, result_return_id, "ret");
+ (Result, result_fail_id, "fail");
+ (Error, error_failure_id, "panic");
+ (* No Fuel::Zero on purpose *)
+ (* No Fuel::Succ on purpose *)
+ ]
+ | HOL4 ->
+ [
+ (Result, result_return_id, "Return");
+ (Result, result_fail_id, "Fail");
+ (Error, error_failure_id, "Failure");
+ (* No Fuel::Zero on purpose *)
+ (* No Fuel::Succ on purpose *)
+ ]
+
+let assumed_llbc_functions () :
+ (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
+ let rg0 = Some T.RegionGroupId.zero in
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ [
+ (ArrayIndexShared, None, "array_index_usize");
+ (ArrayIndexMut, None, "array_index_usize");
+ (ArrayIndexMut, rg0, "array_update_usize");
+ (ArrayToSliceShared, None, "array_to_slice");
+ (ArrayToSliceMut, None, "array_to_slice");
+ (ArrayToSliceMut, rg0, "array_from_slice");
+ (ArrayRepeat, None, "array_repeat");
+ (SliceIndexShared, None, "slice_index_usize");
+ (SliceIndexMut, None, "slice_index_usize");
+ (SliceIndexMut, rg0, "slice_update_usize");
+ (SliceLen, None, "slice_len");
+ ]
+ | Lean ->
+ [
+ (ArrayIndexShared, None, "Array.index_usize");
+ (ArrayIndexMut, None, "Array.index_usize");
+ (ArrayIndexMut, rg0, "Array.update_usize");
+ (ArrayToSliceShared, None, "Array.to_slice");
+ (ArrayToSliceMut, None, "Array.to_slice");
+ (ArrayToSliceMut, rg0, "Array.from_slice");
+ (ArrayRepeat, None, "Array.repeat");
+ (SliceIndexShared, None, "Slice.index_usize");
+ (SliceIndexMut, None, "Slice.index_usize");
+ (SliceIndexMut, rg0, "Slice.update_usize");
+ (SliceLen, None, "Slice.len");
+ ]
+
+let assumed_pure_functions () : (pure_assumed_fun_id * string) list =
+ match !backend with
+ | FStar ->
+ [
+ (Return, "return");
+ (Fail, "fail");
+ (Assert, "massert");
+ (FuelDecrease, "decrease");
+ (FuelEqZero, "is_zero");
+ ]
+ | Coq ->
+ (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *)
+ [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ]
+ | Lean ->
+ (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *)
+ [ (Return, "return"); (Fail, "fail_"); (Assert, "massert") ]
+ | HOL4 ->
+ (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *)
+ [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ]
+
+let names_map_init () : names_map_init =
+ {
+ keywords = keywords ();
+ assumed_adts = assumed_adts ();
+ assumed_structs = assumed_struct_constructors ();
+ assumed_variants = assumed_variants ();
+ assumed_llbc_functions = assumed_llbc_functions ();
+ assumed_pure_functions = assumed_pure_functions ();
+ }
+
+let extract_unop (extract_expr : bool -> texpression -> unit)
+ (fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit
+ =
+ match unop with
+ | Not | Neg _ ->
+ let unop = unop_name unop in
+ if inside then F.pp_print_string fmt "(";
+ F.pp_print_string fmt unop;
+ F.pp_print_space fmt ();
+ extract_expr true arg;
+ if inside then F.pp_print_string fmt ")"
+ | Cast (src, tgt) -> (
+ (* HOL4 has a special treatment: because it doesn't support dependent
+ types, we don't have a specific operator for the cast *)
+ match !backend with
+ | HOL4 ->
+ (* Casting, say, an u32 to an i32 would be done as follows:
+ {[
+ mk_i32 (u32_to_int x)
+ ]}
+ *)
+ if inside then F.pp_print_string fmt "(";
+ F.pp_print_string fmt ("mk_" ^ int_name tgt);
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "(";
+ F.pp_print_string fmt (int_name src ^ "_to_int");
+ F.pp_print_space fmt ();
+ extract_expr true arg;
+ F.pp_print_string fmt ")";
+ if inside then F.pp_print_string fmt ")"
+ | FStar | Coq | Lean ->
+ (* Rem.: the source type is an implicit parameter *)
+ if inside then F.pp_print_string fmt "(";
+ let cast_str =
+ match !backend with
+ | Coq | FStar -> "scalar_cast"
+ | Lean -> (* TODO: I8.cast, I16.cast, etc.*) "Scalar.cast"
+ | HOL4 -> raise (Failure "Unreachable")
+ in
+ F.pp_print_string fmt cast_str;
+ F.pp_print_space fmt ();
+ if !backend <> Lean then (
+ F.pp_print_string fmt
+ (StringUtils.capitalize_first_letter
+ (PrintPure.integer_type_to_string src));
+ F.pp_print_space fmt ());
+ if !backend = Lean then F.pp_print_string fmt ("." ^ int_name tgt)
+ else
+ F.pp_print_string fmt
+ (StringUtils.capitalize_first_letter
+ (PrintPure.integer_type_to_string tgt));
+ F.pp_print_space fmt ();
+ extract_expr true arg;
+ if inside then F.pp_print_string fmt ")")
+
+(** [extract_expr] : the boolean argument is [inside] *)
+let extract_binop (extract_expr : bool -> texpression -> unit)
+ (fmt : F.formatter) (inside : bool) (binop : E.binop)
+ (int_ty : integer_type) (arg0 : texpression) (arg1 : texpression) : unit =
+ if inside then F.pp_print_string fmt "(";
+ (* Some binary operations have a special notation depending on the backend *)
+ (match (!backend, binop) with
+ | HOL4, (Eq | Ne)
+ | (FStar | Coq | Lean), (Eq | Lt | Le | Ne | Ge | Gt)
+ | Lean, (Div | Rem | Add | Sub | Mul) ->
+ let binop =
+ match binop with
+ | Eq -> "="
+ | Lt -> "<"
+ | Le -> "<="
+ | Ne -> if !backend = Lean then "!=" else "<>"
+ | Ge -> ">="
+ | Gt -> ">"
+ | Div -> "/"
+ | Rem -> "%"
+ | Add -> "+"
+ | Sub -> "-"
+ | Mul -> "*"
+ | _ -> raise (Failure "Unreachable")
+ in
+ let binop =
+ match !backend with FStar | Lean | HOL4 -> binop | Coq -> "s" ^ binop
+ in
+ extract_expr false arg0;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt binop;
+ F.pp_print_space fmt ();
+ extract_expr false arg1
+ | _ ->
+ let binop = named_binop_name binop int_ty in
+ F.pp_print_string fmt binop;
+ F.pp_print_space fmt ();
+ extract_expr true arg0;
+ F.pp_print_space fmt ();
+ extract_expr true arg1);
+ if inside then F.pp_print_string fmt ")"
+
+let type_decl_kind_to_qualif (kind : decl_kind)
+ (type_kind : type_decl_kind option) : string option =
+ match !backend with
+ | FStar -> (
+ match kind with
+ | SingleNonRec -> Some "type"
+ | SingleRec -> Some "type"
+ | MutRecFirst -> Some "type"
+ | MutRecInner -> Some "and"
+ | MutRecLast -> Some "and"
+ | Assumed -> Some "assume type"
+ | Declared -> Some "val")
+ | Coq -> (
+ match (kind, type_kind) with
+ | SingleNonRec, Some Enum -> Some "Inductive"
+ | SingleNonRec, Some Struct -> Some "Record"
+ | (SingleRec | MutRecFirst), Some _ -> Some "Inductive"
+ | (MutRecInner | MutRecLast), Some _ ->
+ (* Coq doesn't support groups of mutually recursive definitions which mix
+ * records and inducties: we convert everything to records if this happens
+ *)
+ Some "with"
+ | (Assumed | Declared), None -> Some "Axiom"
+ | SingleNonRec, None ->
+ (* This is for traits *)
+ Some "Record"
+ | _ ->
+ raise
+ (Failure
+ ("Unexpected: (" ^ show_decl_kind kind ^ ", "
+ ^ Print.option_to_string show_type_decl_kind type_kind
+ ^ ")")))
+ | Lean -> (
+ match kind with
+ | SingleNonRec ->
+ if type_kind = Some Struct then Some "structure" else Some "inductive"
+ | SingleRec -> Some "inductive"
+ | MutRecFirst -> Some "inductive"
+ | MutRecInner -> Some "inductive"
+ | MutRecLast -> Some "inductive"
+ | Assumed -> Some "axiom"
+ | Declared -> Some "axiom")
+ | HOL4 -> None
+
+let fun_decl_kind_to_qualif (kind : decl_kind) : string option =
+ match !backend with
+ | FStar -> (
+ match kind with
+ | SingleNonRec -> Some "let"
+ | SingleRec -> Some "let rec"
+ | MutRecFirst -> Some "let rec"
+ | MutRecInner -> Some "and"
+ | MutRecLast -> Some "and"
+ | Assumed -> Some "assume val"
+ | Declared -> Some "val")
+ | Coq -> (
+ match kind with
+ | SingleNonRec -> Some "Definition"
+ | SingleRec -> Some "Fixpoint"
+ | MutRecFirst -> Some "Fixpoint"
+ | MutRecInner -> Some "with"
+ | MutRecLast -> Some "with"
+ | Assumed -> Some "Axiom"
+ | Declared -> Some "Axiom")
+ | Lean -> (
+ match kind with
+ | SingleNonRec -> Some "def"
+ | SingleRec -> Some "divergent def"
+ | MutRecFirst -> Some "mutual divergent def"
+ | MutRecInner -> Some "divergent def"
+ | MutRecLast -> Some "divergent def"
+ | Assumed -> Some "axiom"
+ | Declared -> Some "axiom")
+ | HOL4 -> None
+
+(** The type of types.
+
+ TODO: move inside the formatter?
+ *)
+let type_keyword () =
+ match !backend with
+ | FStar -> "Type0"
+ | Coq | Lean -> "Type"
+ | HOL4 -> raise (Failure "Unexpected")
+
+(**
+ [ctx]: we use the context to lookup type definitions, to retrieve type names.
+ This is used to compute variable names, when they have no basenames: in this
+ case we use the first letter of the type name.
+
+ [variant_concatenate_type_name]: if true, add the type name as a prefix
+ to the variant names.
+ Ex.:
+ In Rust:
+ {[
+ enum List = {
+ Cons(u32, Box<List>),x
+ Nil,
+ }
+ ]}
+
+ F*, if option activated:
+ {[
+ type list =
+ | ListCons : u32 -> list -> list
+ | ListNil : list
+ ]}
+
+ F*, if option not activated:
+ {[
+ type list =
+ | Cons : u32 -> list -> list
+ | Nil : list
+ ]}
+
+ Rk.: this should be true by default, because in Rust all the variant names
+ are actively uniquely identifier by the type name [List::Cons(...)], while
+ in other languages it is not necessarily the case, and thus clashes can mess
+ up type checking. Note that some languages actually forbids the name clashes
+ (it is the case of F* ).
+ *)
+let mk_formatter (ctx : trans_ctx) (crate_name : string)
+ (variant_concatenate_type_name : bool) : formatter =
+ let int_name = int_name in
+
+ (* Prepare a name.
+ * The first id elem is always the crate: if it is the local crate,
+ * we remove it.
+ * We also remove all the disambiguators, then convert everything to strings.
+ * **Rmk:** because we remove the disambiguators, there may be name collisions
+ * (which is ok, because we check for name collisions and fail if there is any).
+ *)
+ let get_name (name : name) : string list =
+ (* Rmk.: initially we only filtered the disambiguators equal to 0 *)
+ let name = Names.filter_disambiguators name in
+ match name with
+ | Ident crate :: name ->
+ let name = if crate = crate_name then name else Ident crate :: name in
+ let name =
+ List.map
+ (function
+ | Names.Ident s -> s
+ | Disambiguator d -> Names.Disambiguator.to_string d)
+ name
+ in
+ name
+ | _ ->
+ raise (Failure ("Unexpected name shape: " ^ Print.name_to_string name))
+ in
+ let flatten_name (name : string list) : string =
+ match !backend with
+ | FStar | Coq | HOL4 -> String.concat "_" name
+ | Lean -> String.concat "." name
+ in
+ let get_type_name = get_name in
+ let get_type_name_no_suffix name =
+ match !backend with
+ | FStar | Coq | HOL4 -> String.concat "_" (get_type_name name)
+ | Lean -> String.concat "." (get_type_name name)
+ in
+ let type_name name =
+ match !backend with
+ | FStar ->
+ StringUtils.lowercase_first_letter (get_type_name_no_suffix name ^ "_t")
+ | Coq | HOL4 -> get_type_name_no_suffix name ^ "_t"
+ | Lean -> get_type_name_no_suffix name
+ in
+ let field_name (def_name : name) (field_id : FieldId.id)
+ (field_name : string option) : string =
+ let field_name_s =
+ match field_name with
+ | Some field_name -> field_name
+ | None ->
+ (* TODO: extract structs with no field names to tuples *)
+ FieldId.to_string field_id
+ in
+ if !Config.record_fields_short_names then
+ if field_name = None then (* TODO: this is a bit ugly *)
+ "_" ^ field_name_s
+ else field_name_s
+ else
+ let def_name = get_type_name_no_suffix def_name ^ "_" ^ field_name_s in
+ match !backend with
+ | Lean | HOL4 -> def_name
+ | Coq | FStar -> StringUtils.lowercase_first_letter def_name
+ in
+ let variant_name (def_name : name) (variant : string) : string =
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ let variant = to_camel_case variant in
+ if variant_concatenate_type_name then
+ StringUtils.capitalize_first_letter
+ (get_type_name_no_suffix def_name ^ "_" ^ variant)
+ else variant
+ | Lean -> variant
+ in
+ let struct_constructor (basename : name) : string =
+ let tname = type_name basename in
+ ExtractBuiltin.mk_struct_constructor tname
+ in
+ let get_fun_name fname =
+ let fname = get_name fname in
+ (* TODO: don't convert to snake case for Coq, HOL4, F* *)
+ let fname = flatten_name fname in
+ match !backend with
+ | FStar | Coq | HOL4 -> StringUtils.lowercase_first_letter fname
+ | Lean -> fname
+ in
+ let global_name (name : global_name) : string =
+ (* Converting to snake case also lowercases the letters (in Rust, global
+ * names are written in capital letters). *)
+ let parts = List.map to_snake_case (get_name name) in
+ String.concat "_" parts
+ in
+ let fun_name (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option)
+ (num_rgs : int) (rg : region_group_info option) (filter_info : bool * int)
+ : string =
+ let fname = get_fun_name fname in
+ (* Compute the suffix *)
+ let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in
+ (* Concatenate *)
+ fname ^ suffix
+ in
+
+ let trait_decl_name (trait_decl : trait_decl) : string =
+ type_name trait_decl.name
+ in
+
+ let trait_impl_name (trait_decl : trait_decl) (trait_impl : trait_impl) :
+ string =
+ (* TODO: provisional: we concatenate the trait impl name (which is its type)
+ with the trait decl name *)
+ let trait_decl =
+ let name = trait_decl.name in
+ let name = get_type_name_no_suffix name ^ "Inst" in
+ (* Remove the occurrences of '.' *)
+ String.concat "" (String.split_on_char '.' name)
+ in
+ let name = flatten_name (get_type_name trait_impl.name @ [ trait_decl ]) in
+ match !backend with
+ | FStar -> StringUtils.lowercase_first_letter name
+ | Coq | HOL4 | Lean -> name
+ in
+
+ let trait_decl_constructor (trait_decl : trait_decl) : string =
+ let name = trait_decl_name trait_decl in
+ ExtractBuiltin.mk_struct_constructor name
+ in
+
+ let trait_parent_clause_name (trait_decl : trait_decl) (clause : trait_clause)
+ : string =
+ (* TODO: improve - it would be better to not use indices *)
+ let clause = "parent_clause_" ^ TraitClauseId.to_string clause.clause_id in
+ if !Config.record_fields_short_names then clause
+ else trait_decl_name trait_decl ^ "_" ^ clause
+ in
+ let trait_type_name (trait_decl : trait_decl) (item : string) : string =
+ let name =
+ if !Config.record_fields_short_names then item
+ else trait_decl_name trait_decl ^ "_" ^ item
+ in
+ (* Constants are usually all capital letters.
+ Some backends do not support field names starting with a capital letter,
+ and it may be weird to lowercase everything (especially as it may lead
+ to more name collisions): we add a prefix when necessary.
+ For instance, it gives: "U" -> "tU"
+ Note that for some backends we prepend the type name (because those backends
+ can't disambiguate fields coming from different ADTs if they have the same
+ names), and thus don't need to add a prefix starting with a lowercase.
+ *)
+ match !backend with FStar -> "t" ^ name | Coq | Lean | HOL4 -> name
+ in
+ let trait_const_name (trait_decl : trait_decl) (item : string) : string =
+ let name =
+ if !Config.record_fields_short_names then item
+ else trait_decl_name trait_decl ^ "_" ^ item
+ in
+ (* See [trait_type_name] *)
+ match !backend with FStar -> "c" ^ name | Coq | Lean | HOL4 -> name
+ in
+ let trait_method_name (trait_decl : trait_decl) (item : string) : string =
+ if !Config.record_fields_short_names then item
+ else trait_decl_name trait_decl ^ "_" ^ item
+ in
+ let trait_type_clause_name (trait_decl : trait_decl) (item : string)
+ (clause : trait_clause) : string =
+ (* TODO: improve - it would be better to not use indices *)
+ trait_type_name trait_decl item
+ ^ "_clause_"
+ ^ TraitClauseId.to_string clause.clause_id
+ in
+
+ let termination_measure_name (_fid : A.FunDeclId.id) (fname : fun_name)
+ (num_loops : int) (loop_id : LoopId.id option) : string =
+ let fname = get_fun_name fname in
+ let lp_suffix = default_fun_loop_suffix num_loops loop_id in
+ (* Compute the suffix *)
+ let suffix =
+ match !Config.backend with
+ | FStar -> "_decreases"
+ | Lean -> "_terminates"
+ | Coq | HOL4 -> raise (Failure "Unexpected")
+ in
+ (* Concatenate *)
+ fname ^ lp_suffix ^ suffix
+ in
+
+ let decreases_proof_name (_fid : A.FunDeclId.id) (fname : fun_name)
+ (num_loops : int) (loop_id : LoopId.id option) : string =
+ let fname = get_fun_name fname in
+ let lp_suffix = default_fun_loop_suffix num_loops loop_id in
+ (* Compute the suffix *)
+ let suffix =
+ match !Config.backend with
+ | Lean -> "_decreases"
+ | FStar | Coq | HOL4 -> raise (Failure "Unexpected")
+ in
+ (* Concatenate *)
+ fname ^ lp_suffix ^ suffix
+ in
+
+ let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty)
+ : string =
+ (* Small helper to derive var names from ADT type names.
+
+ We do the following:
+ - convert the type name to snake case
+ - take the first letter of every "letter group"
+ Ex.: "HashMap" -> "hash_map" -> "hm"
+ *)
+ let name_from_type_ident (name : string) : string =
+ let cl = to_snake_case name in
+ let cl = String.split_on_char '_' cl in
+ let cl = List.filter (fun s -> String.length s > 0) cl in
+ assert (List.length cl > 0);
+ let cl = List.map (fun s -> s.[0]) cl in
+ StringUtils.string_of_chars cl
+ in
+ (* If there is a basename, we use it *)
+ match basename with
+ | Some basename ->
+ (* This should be a no-op *)
+ to_snake_case basename
+ | None -> (
+ (* No basename: we use the first letter of the type *)
+ match ty with
+ | Adt (type_id, generics) -> (
+ match type_id with
+ | Tuple ->
+ (* The "pair" case is frequent enough to have its special treatment *)
+ if List.length generics.types = 2 then "p" else "t"
+ | Assumed Result -> "r"
+ | Assumed Error -> ConstStrings.error_basename
+ | Assumed Fuel -> ConstStrings.fuel_basename
+ | Assumed Array -> "a"
+ | Assumed Slice -> "s"
+ | Assumed Str -> "s"
+ | Assumed State -> ConstStrings.state_basename
+ | Assumed (RawPtr _) -> "p"
+ | AdtId adt_id ->
+ let def = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in
+ (* Derive the var name from the last ident of the type name
+ * Ex.: ["hashmap"; "HashMap"] ~~> "HashMap" -> "hash_map" -> "hm"
+ *)
+ (* The name shouldn't be empty, and its last element should
+ * be an ident *)
+ let cl = List.nth def.name (List.length def.name - 1) in
+ name_from_type_ident (Names.as_ident cl))
+ | TypeVar _ -> (
+ (* TODO: use "t" also for F* *)
+ match !backend with
+ | FStar -> "x" (* lacking inspiration here... *)
+ | Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *))
+ | Literal lty -> (
+ match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i")
+ | Arrow _ -> "f"
+ | TraitType (_, _, name) -> name_from_type_ident name)
+ in
+ let type_var_basename (_varset : StringSet.t) (basename : string) : string =
+ (* Rust type variables are snake-case and start with a capital letter *)
+ match !backend with
+ | FStar ->
+ (* This is *not* a no-op: this removes the capital letter *)
+ to_snake_case basename
+ | HOL4 ->
+ (* In HOL4, type variable names must start with "'" *)
+ "'" ^ to_snake_case basename
+ | Coq | Lean -> basename
+ in
+ let const_generic_var_basename (_varset : StringSet.t) (basename : string) :
+ string =
+ (* Rust type variables are snake-case and start with a capital letter *)
+ match !backend with
+ | FStar | HOL4 ->
+ (* This is *not* a no-op: this removes the capital letter *)
+ to_snake_case basename
+ | Coq | Lean -> basename
+ in
+ let trait_clause_basename (_varset : StringSet.t) (_clause : trait_clause) :
+ string =
+ (* TODO: actually use the clause to derive the name *)
+ "inst"
+ in
+ let trait_self_clause_basename = "self_clause" in
+ let append_index (basename : string) (i : int) : string =
+ basename ^ string_of_int i
+ in
+
+ let extract_literal (fmt : F.formatter) (inside : bool) (cv : literal) : unit
+ =
+ match cv with
+ | Scalar sv -> (
+ match !backend with
+ | FStar -> F.pp_print_string fmt (Z.to_string sv.PV.value)
+ | Coq | HOL4 | Lean ->
+ let print_brackets = inside && !backend = HOL4 in
+ if print_brackets then F.pp_print_string fmt "(";
+ (match !backend with
+ | Coq | Lean -> ()
+ | HOL4 ->
+ F.pp_print_string fmt ("int_to_" ^ int_name sv.PV.int_ty);
+ F.pp_print_space fmt ()
+ | _ -> raise (Failure "Unreachable"));
+ (* We need to add parentheses if the value is negative *)
+ if sv.PV.value >= Z.of_int 0 then
+ F.pp_print_string fmt (Z.to_string sv.PV.value)
+ else if !backend = Lean then
+ (* TODO: parsing issues with Lean because there are ambiguous
+ interpretations between int values and nat values *)
+ F.pp_print_string fmt
+ ("(-(" ^ Z.to_string (Z.neg sv.PV.value) ^ ":Int))")
+ else F.pp_print_string fmt ("(" ^ Z.to_string sv.PV.value ^ ")");
+ (match !backend with
+ | Coq ->
+ let iname = int_name sv.PV.int_ty in
+ F.pp_print_string fmt ("%" ^ iname)
+ | Lean ->
+ let iname = String.lowercase_ascii (int_name sv.PV.int_ty) in
+ F.pp_print_string fmt ("#" ^ iname)
+ | HOL4 -> ()
+ | _ -> raise (Failure "Unreachable"));
+ if print_brackets then F.pp_print_string fmt ")")
+ | Bool b ->
+ let b =
+ match !backend with
+ | HOL4 -> if b then "T" else "F"
+ | Coq | FStar | Lean -> if b then "true" else "false"
+ in
+ F.pp_print_string fmt b
+ | Char c -> (
+ match !backend with
+ | HOL4 ->
+ (* [#"a"] is a notation for [CHR 97] (97 is the ASCII code for 'a') *)
+ F.pp_print_string fmt ("#\"" ^ String.make 1 c ^ "\"")
+ | FStar | Lean -> F.pp_print_string fmt ("'" ^ String.make 1 c ^ "'")
+ | Coq ->
+ if inside then F.pp_print_string fmt "(";
+ F.pp_print_string fmt "char_of_byte";
+ F.pp_print_space fmt ();
+ (* Convert the the char to ascii *)
+ let c =
+ let i = Char.code c in
+ let x0 = i / 16 in
+ let x1 = i mod 16 in
+ "Coq.Init.Byte.x" ^ string_of_int x0 ^ string_of_int x1
+ in
+ F.pp_print_string fmt c;
+ if inside then F.pp_print_string fmt ")")
+ in
+ let bool_name = if !backend = Lean then "Bool" else "bool" in
+ let char_name = if !backend = Lean then "Char" else "char" in
+ let str_name = if !backend = Lean then "String" else "string" in
+ {
+ bool_name;
+ char_name;
+ int_name;
+ str_name;
+ type_decl_kind_to_qualif;
+ fun_decl_kind_to_qualif;
+ field_name;
+ variant_name;
+ struct_constructor;
+ type_name;
+ global_name;
+ fun_name;
+ termination_measure_name;
+ decreases_proof_name;
+ trait_decl_name;
+ trait_impl_name;
+ trait_decl_constructor;
+ trait_parent_clause_name;
+ trait_const_name;
+ trait_type_name;
+ trait_method_name;
+ trait_type_clause_name;
+ var_basename;
+ type_var_basename;
+ const_generic_var_basename;
+ trait_self_clause_basename;
+ trait_clause_basename;
+ append_index;
+ extract_literal;
+ extract_unop;
+ extract_binop;
+ }
+
+let mk_formatter_and_names_maps (ctx : trans_ctx) (crate_name : string)
+ (variant_concatenate_type_name : bool) : formatter * names_maps =
+ let fmt = mk_formatter ctx crate_name variant_concatenate_type_name in
+ let names_maps = initialize_names_maps fmt (names_map_init ()) in
+ (fmt, names_maps)
+
+let is_single_opaque_fun_decl_group (dg : Pure.fun_decl list) : bool =
+ match dg with [ d ] -> d.body = None | _ -> false
+
+let is_single_opaque_type_decl_group (dg : Pure.type_decl list) : bool =
+ match dg with [ d ] -> d.kind = Opaque | _ -> false
+
+let is_empty_record_type_decl (d : Pure.type_decl) : bool = d.kind = Struct []
+
+let is_empty_record_type_decl_group (dg : Pure.type_decl list) : bool =
+ match dg with [ d ] -> is_empty_record_type_decl d | _ -> false
+
+(** In some provers, groups of definitions must be delimited.
+
+ - in Coq, *every* group (including singletons) must end with "."
+ - in Lean, groups of mutually recursive definitions must end with "end"
+ - in HOL4 (in most situations) the whole group must be within a `Define` command
+
+ Calls to {!extract_fun_decl} should be inserted between calls to
+ {!start_fun_decl_group} and {!end_fun_decl_group}.
+
+ TODO: maybe those [{start/end}_decl_group] functions are not that much a good
+ idea and we should merge them with the corresponding [extract_decl] functions.
+ *)
+let start_fun_decl_group (ctx : extraction_ctx) (fmt : F.formatter)
+ (is_rec : bool) (dg : Pure.fun_decl list) =
+ match !backend with
+ | FStar | Coq | Lean -> ()
+ | HOL4 ->
+ (* In HOL4, opaque functions have a special treatment *)
+ if is_single_opaque_fun_decl_group dg then ()
+ else
+ let compute_fun_def_name (def : Pure.fun_decl) : string =
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx ^ "_def"
+ in
+ let names = List.map compute_fun_def_name dg in
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Open the box for the delimiters *)
+ F.pp_open_vbox fmt 0;
+ (* Open the box for the definitions themselves *)
+ F.pp_open_vbox fmt ctx.indent_incr;
+ (* Print the delimiters *)
+ if is_rec then
+ F.pp_print_string fmt
+ ("val [" ^ String.concat ", " names ^ "] = DefineDiv ‘")
+ else (
+ assert (List.length names = 1);
+ let name = List.hd names in
+ F.pp_print_string fmt ("val " ^ name ^ " = Define ‘"));
+ F.pp_print_cut fmt ()
+
+(** See {!start_fun_decl_group}. *)
+let end_fun_decl_group (fmt : F.formatter) (is_rec : bool)
+ (dg : Pure.fun_decl list) =
+ match !backend with
+ | FStar -> ()
+ | Coq ->
+ (* For aesthetic reasons, we print the Coq end group delimiter directly
+ in {!extract_fun_decl}. *)
+ ()
+ | Lean ->
+ (* We must add the "end" keyword to groups of mutually recursive functions *)
+ if is_rec && List.length dg > 1 then (
+ F.pp_print_cut fmt ();
+ F.pp_print_string fmt "end";
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0)
+ else ()
+ | HOL4 ->
+ (* In HOL4, opaque functions have a special treatment *)
+ if is_single_opaque_fun_decl_group dg then ()
+ else (
+ (* Close the box for the definitions *)
+ F.pp_close_box fmt ();
+ (* Print the end delimiter *)
+ F.pp_print_cut fmt ();
+ F.pp_print_string fmt "’";
+ (* Close the box for the delimiters *)
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0)
+
+(** See {!start_fun_decl_group}: similar usage, but for the type declarations. *)
+let start_type_decl_group (ctx : extraction_ctx) (fmt : F.formatter)
+ (is_rec : bool) (dg : Pure.type_decl list) =
+ match !backend with
+ | FStar | Coq -> ()
+ | Lean ->
+ if is_rec && List.length dg > 1 then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "mutual";
+ F.pp_print_space fmt ())
+ | HOL4 ->
+ (* In HOL4, opaque types and empty records have a special treatment *)
+ if
+ is_single_opaque_type_decl_group dg
+ || is_empty_record_type_decl_group dg
+ then ()
+ else (
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Open the box for the delimiters *)
+ F.pp_open_vbox fmt 0;
+ (* Open the box for the definitions themselves *)
+ F.pp_open_vbox fmt ctx.indent_incr;
+ (* Print the delimiters *)
+ F.pp_print_string fmt "Datatype:";
+ F.pp_print_cut fmt ())
+
+(** See {!start_fun_decl_group}. *)
+let end_type_decl_group (fmt : F.formatter) (is_rec : bool)
+ (dg : Pure.type_decl list) =
+ match !backend with
+ | FStar -> ()
+ | Coq ->
+ (* For aesthetic reasons, we print the Coq end group delimiter directly
+ in {!extract_fun_decl}. *)
+ ()
+ | Lean ->
+ (* We must add the "end" keyword to groups of mutually recursive functions *)
+ if is_rec && List.length dg > 1 then (
+ F.pp_print_cut fmt ();
+ F.pp_print_string fmt "end";
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0)
+ else ()
+ | HOL4 ->
+ (* In HOL4, opaque types and empty records have a special treatment *)
+ if
+ is_single_opaque_type_decl_group dg
+ || is_empty_record_type_decl_group dg
+ then ()
+ else (
+ (* Close the box for the definitions *)
+ F.pp_close_box fmt ();
+ (* Print the end delimiter *)
+ F.pp_print_cut fmt ();
+ F.pp_print_string fmt "End";
+ (* Close the box for the delimiters *)
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0)
+
+let unit_name () =
+ match !backend with Lean -> "Unit" | Coq | FStar | HOL4 -> "unit"
+
+(** Small helper *)
+let extract_arrow (fmt : F.formatter) () : unit =
+ if !Config.backend = Lean then F.pp_print_string fmt "→"
+ else F.pp_print_string fmt "->"
+
+let extract_const_generic (ctx : extraction_ctx) (fmt : F.formatter)
+ (inside : bool) (cg : const_generic) : unit =
+ match cg with
+ | ConstGenericGlobal id ->
+ let s = ctx_get_global id ctx in
+ F.pp_print_string fmt s
+ | ConstGenericValue v -> ctx.fmt.extract_literal fmt inside v
+ | ConstGenericVar id ->
+ let s = ctx_get_const_generic_var id ctx in
+ F.pp_print_string fmt s
+
+let extract_literal_type (ctx : extraction_ctx) (fmt : F.formatter)
+ (ty : literal_type) : unit =
+ match ty with
+ | Bool -> F.pp_print_string fmt ctx.fmt.bool_name
+ | Char -> F.pp_print_string fmt ctx.fmt.char_name
+ | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty)
+
+(** [inside] constrols whether we should add parentheses or not around type
+ applications (if [true] we add parentheses).
+
+ [no_params_tys]: for all the types inside this set, do not print the type parameters.
+ This is used for HOL4. As polymorphism is uniform in HOL4, printing the
+ type parameters in the recursive definitions is useless (and actually
+ forbidden).
+
+ For instance, where in F* we would write:
+ {[
+ type list a = | Nil : list a | Cons : a -> list a -> list a
+ ]}
+
+ In HOL4 we would simply write:
+ {[
+ Datatype:
+ list = Nil 'a | Cons 'a list
+ End
+ ]}
+ *)
+let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit =
+ let extract_rec = extract_ty ctx fmt no_params_tys in
+ match ty with
+ | Adt (type_id, generics) -> (
+ let has_params = generics <> empty_generic_args in
+ match type_id with
+ | Tuple ->
+ (* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type:
+ * we have to write [unit]... *)
+ if generics.types = [] then F.pp_print_string fmt (unit_name ())
+ else (
+ F.pp_print_string fmt "(";
+ Collections.List.iter_link
+ (fun () ->
+ F.pp_print_space fmt ();
+ let product =
+ match !backend with
+ | FStar -> "&"
+ | Coq -> "*"
+ | Lean -> "×"
+ | HOL4 -> "#"
+ in
+ F.pp_print_string fmt product;
+ F.pp_print_space fmt ())
+ (extract_rec true) generics.types;
+ F.pp_print_string fmt ")")
+ | AdtId _ | Assumed _ -> (
+ (* HOL4 behaves differently. Where in Coq/FStar/Lean we would write:
+ `tree a b`
+
+ In HOL4 we would write:
+ `('a, 'b) tree`
+ *)
+ match !backend with
+ | FStar | Coq | Lean ->
+ let print_paren = inside && has_params in
+ if print_paren then F.pp_print_string fmt "(";
+ (* TODO: for now, only the opaque *functions* are extracted in the
+ opaque module. The opaque *types* are assumed. *)
+ F.pp_print_string fmt (ctx_get_type type_id ctx);
+ (* We might need to filter the type arguments, if the type
+ is builtin (for instance, we filter the global allocator type
+ argument for `Vec`). *)
+ let generics =
+ match type_id with
+ | AdtId id -> (
+ match
+ TypeDeclId.Map.find_opt id ctx.types_filter_type_args_map
+ with
+ | None -> generics
+ | Some filter ->
+ let types = List.combine filter generics.types in
+ let types =
+ List.filter_map
+ (fun (b, ty) -> if b then Some ty else None)
+ types
+ in
+ { generics with types })
+ | _ -> generics
+ in
+ extract_generic_args ctx fmt no_params_tys generics;
+ if print_paren then F.pp_print_string fmt ")"
+ | HOL4 ->
+ let { types; const_generics; trait_refs } = generics in
+ (* Const generics are not supported in HOL4 *)
+ assert (const_generics = []);
+ let print_tys =
+ match type_id with
+ | AdtId id -> not (TypeDeclId.Set.mem id no_params_tys)
+ | Assumed _ -> true
+ | _ -> raise (Failure "Unreachable")
+ in
+ if types <> [] && print_tys then (
+ let print_paren = List.length types > 1 in
+ if print_paren then F.pp_print_string fmt "(";
+ Collections.List.iter_link
+ (fun () ->
+ F.pp_print_string fmt ",";
+ F.pp_print_space fmt ())
+ (extract_rec true) types;
+ if print_paren then F.pp_print_string fmt ")";
+ F.pp_print_space fmt ());
+ F.pp_print_string fmt (ctx_get_type type_id ctx);
+ if trait_refs <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_trait_ref ctx fmt no_params_tys true)
+ trait_refs)))
+ | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx)
+ | Literal lty -> extract_literal_type ctx fmt lty
+ | Arrow (arg_ty, ret_ty) ->
+ if inside then F.pp_print_string fmt "(";
+ extract_rec false arg_ty;
+ F.pp_print_space fmt ();
+ extract_arrow fmt ();
+ F.pp_print_space fmt ();
+ extract_rec false ret_ty;
+ if inside then F.pp_print_string fmt ")"
+ | TraitType (trait_ref, generics, type_name) -> (
+ if !parameterize_trait_types then raise (Failure "Unimplemented")
+ else
+ let type_name =
+ ctx_get_trait_type trait_ref.trait_decl_ref.trait_decl_id type_name
+ ctx
+ in
+ let add_brackets (s : string) =
+ if !backend = Coq then "(" ^ s ^ ")" else s
+ in
+ (* There may be a special treatment depending on the instance id.
+ See the comments for {!extract_trait_instance_id_with_dot}.
+ TODO: there should be a cleaner way to do. The annoying thing
+ here is that if we project directly over the self clause, then
+ we have to be careful (we may not have to print the "Self.").
+ Otherwise, we can directly call {!extract_trait_ref}.
+ *)
+ match trait_ref.trait_id with
+ | Self ->
+ assert (generics = empty_generic_args);
+ assert (trait_ref.generics = empty_generic_args);
+ extract_trait_instance_id_with_dot ctx fmt no_params_tys false
+ trait_ref.trait_id;
+ F.pp_print_string fmt type_name
+ | _ ->
+ (* HOL4 doesn't have 1st class types *)
+ assert (!backend <> HOL4);
+ let use_brackets = generics <> empty_generic_args in
+ if use_brackets then F.pp_print_string fmt "(";
+ extract_trait_ref ctx fmt no_params_tys false trait_ref;
+ extract_generic_args ctx fmt no_params_tys generics;
+ if use_brackets then F.pp_print_string fmt ")";
+ F.pp_print_string fmt ("." ^ add_brackets type_name))
+
+and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_ref) : unit =
+ let use_brackets = tr.generics <> empty_generic_args && inside in
+ if use_brackets then F.pp_print_string fmt "(";
+ (* We may need to filter the parameters if the trait is builtin *)
+ let generics =
+ match tr.trait_id with
+ | TraitImpl id -> (
+ match
+ TraitImplId.Map.find_opt id ctx.trait_impls_filter_type_args_map
+ with
+ | None -> tr.generics
+ | Some filter ->
+ let types =
+ List.filter_map
+ (fun (b, x) -> if b then Some x else None)
+ (List.combine filter tr.generics.types)
+ in
+ { tr.generics with types })
+ | _ -> tr.generics
+ in
+ extract_trait_instance_id ctx fmt no_params_tys inside tr.trait_id;
+ extract_generic_args ctx fmt no_params_tys generics;
+ if use_brackets then F.pp_print_string fmt ")"
+
+and extract_trait_decl_ref (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_decl_ref) :
+ unit =
+ let use_brackets = tr.decl_generics <> empty_generic_args && inside in
+ let name = ctx_get_trait_decl tr.trait_decl_id ctx in
+ if use_brackets then F.pp_print_string fmt "(";
+ F.pp_print_string fmt name;
+ (* There is something subtle here: the trait obligations for the implemented
+ trait are put inside the parent clauses, so we must ignore them here *)
+ let generics = { tr.decl_generics with trait_refs = [] } in
+ extract_generic_args ctx fmt no_params_tys generics;
+ if use_brackets then F.pp_print_string fmt ")"
+
+and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (generics : generic_args) : unit =
+ let { types; const_generics; trait_refs } = generics in
+ if !backend <> HOL4 then (
+ if types <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_ty ctx fmt no_params_tys true)
+ types);
+ if const_generics <> [] then (
+ assert (!backend <> HOL4);
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_const_generic ctx fmt true)
+ const_generics));
+ if trait_refs <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_trait_ref ctx fmt no_params_tys true)
+ trait_refs)
+
+(** We sometimes need to ignore references to `Self` when generating the
+ code, espcially when we project associated items. For this reason we
+ have a special function for the cases where we project from an instance
+ id (e.g., `<Self as Foo>::foo` - note that in the extracted code, the
+ projections are often written with a dot '.').
+ *)
+and extract_trait_instance_id_with_dot (ctx : extraction_ctx)
+ (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (inside : bool)
+ (id : trait_instance_id) : unit =
+ match id with
+ | Self ->
+ (* There are two situations:
+ - we are extracting a declared item and need to refer to another
+ item (for instance, we are extracting a method signature and
+ need to refer to an associated type).
+ We directly refer to the other item (we extract trait declarations
+ as structures, so we can refer to their fields)
+ - we are extracting a provided method for a trait declaration. We
+ refer to the item in the self trait clause (see {!SelfTraitClauseId}).
+
+ Remark: we can't get there for trait *implementations* because then the
+ types should have been normalized.
+ *)
+ if ctx.is_provided_method then
+ (* Provided method: use the trait self clause *)
+ let self_clause = ctx_get_trait_self_clause ctx in
+ F.pp_print_string fmt (self_clause ^ ".")
+ else
+ (* Declaration: nothing to print, we will directly refer to
+ the item. *)
+ ()
+ | _ ->
+ (* Other cases *)
+ extract_trait_instance_id ctx fmt no_params_tys inside id;
+ F.pp_print_string fmt "."
+
+and extract_trait_instance_id (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (inside : bool) (id : trait_instance_id)
+ : unit =
+ let add_brackets (s : string) = if !backend = Coq then "(" ^ s ^ ")" else s in
+ match id with
+ | Self ->
+ (* This has a specific treatment depending on the item we're extracting
+ (associated type, etc.). We should have caught this elsewhere. *)
+ if !Config.fail_hard then
+ raise (Failure "Unexpected occurrence of `Self`")
+ else F.pp_print_string fmt "ERROR(\"Unexpected Self\")"
+ | TraitImpl id ->
+ let name = ctx_get_trait_impl id ctx in
+ F.pp_print_string fmt name
+ | Clause id ->
+ let name = ctx_get_local_trait_clause id ctx in
+ F.pp_print_string fmt name
+ | ParentClause (inst_id, decl_id, clause_id) ->
+ (* Use the trait decl id to lookup the name *)
+ let name = ctx_get_trait_parent_clause decl_id clause_id ctx in
+ extract_trait_instance_id_with_dot ctx fmt no_params_tys true inst_id;
+ F.pp_print_string fmt (add_brackets name)
+ | ItemClause (inst_id, decl_id, item_name, clause_id) ->
+ (* Use the trait decl id to lookup the name *)
+ let name = ctx_get_trait_item_clause decl_id item_name clause_id ctx in
+ extract_trait_instance_id_with_dot ctx fmt no_params_tys true inst_id;
+ F.pp_print_string fmt (add_brackets name)
+ | TraitRef trait_ref ->
+ extract_trait_ref ctx fmt no_params_tys inside trait_ref
+ | UnknownTrait _ ->
+ (* This is an error case *)
+ raise (Failure "Unexpected")
+
+(** Compute the names for all the top-level identifiers used in a type
+ definition (type name, variant names, field names, etc. but not type
+ parameters).
+
+ We need to do this preemptively, beforce extracting any definition,
+ because of recursive definitions.
+ *)
+let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) :
+ extraction_ctx =
+ (* Lookup the builtin information, if there is *)
+ let open ExtractBuiltin in
+ let sname = name_to_simple_name def.name in
+ let info = SimpleNameMap.find_opt sname (builtin_types_map ()) in
+ (* Register the filtering information, if there is *)
+ let ctx =
+ match info with
+ | Some { keep_params = Some keep; _ } ->
+ {
+ ctx with
+ types_filter_type_args_map =
+ TypeDeclId.Map.add def.def_id keep ctx.types_filter_type_args_map;
+ }
+ | _ -> ctx
+ in
+ (* Compute and register the type def name *)
+ let def_name =
+ match info with
+ | None -> ctx.fmt.type_name def.name
+ | Some info -> info.extract_name
+ in
+ let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in
+ (* Compute and register:
+ * - the variant names, if this is an enumeration
+ * - the field names, if this is a structure
+ *)
+ let ctx =
+ match def.kind with
+ | Struct fields ->
+ (* Compute the names *)
+ let field_names, cons_name =
+ match info with
+ | None | Some { body_info = None; _ } ->
+ let field_names =
+ FieldId.mapi
+ (fun fid (field : field) ->
+ (fid, ctx.fmt.field_name def.name fid field.field_name))
+ fields
+ in
+ let cons_name = ctx.fmt.struct_constructor def.name in
+ (field_names, cons_name)
+ | Some { body_info = Some (Struct (cons_name, field_names)); _ } ->
+ let field_names =
+ FieldId.mapi
+ (fun fid (field : field) ->
+ let rust_name = Option.get field.field_name in
+ let name =
+ snd (List.find (fun (n, _) -> n = rust_name) field_names)
+ in
+ (fid, name))
+ fields
+ in
+ (field_names, cons_name)
+ | Some info ->
+ raise
+ (Failure
+ ("Invalid builtin information: "
+ ^ show_builtin_type_info info))
+ in
+ (* Add the fields *)
+ let ctx =
+ List.fold_left
+ (fun ctx (fid, name) ->
+ ctx_add (FieldId (AdtId def.def_id, fid)) name ctx)
+ ctx field_names
+ in
+ (* Add the constructor name *)
+ ctx_add (StructId (AdtId def.def_id)) cons_name ctx
+ | Enum variants ->
+ let variant_names =
+ match info with
+ | None ->
+ VariantId.mapi
+ (fun variant_id (variant : variant) ->
+ let name =
+ ctx.fmt.variant_name def.name variant.variant_name
+ in
+ (* Add the type name prefix for Lean *)
+ let name =
+ if !Config.backend = Lean then
+ let type_name = ctx.fmt.type_name def.name in
+ type_name ^ "." ^ name
+ else name
+ in
+ (variant_id, name))
+ variants
+ | Some { body_info = Some (Enum variant_infos); _ } ->
+ (* We need to compute the map from variant to variant *)
+ let variant_map =
+ StringMap.of_list
+ (List.map
+ (fun (info : builtin_enum_variant_info) ->
+ (info.rust_variant_name, info.extract_variant_name))
+ variant_infos)
+ in
+ VariantId.mapi
+ (fun variant_id (variant : variant) ->
+ (variant_id, StringMap.find variant.variant_name variant_map))
+ variants
+ | _ -> raise (Failure "Invalid builtin information")
+ in
+ List.fold_left
+ (fun ctx (vid, vname) ->
+ ctx_add (VariantId (AdtId def.def_id, vid)) vname ctx)
+ ctx variant_names
+ | Opaque ->
+ (* Nothing to do *)
+ ctx
+ in
+ (* Return *)
+ ctx
+
+(** Print the variants *)
+let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter)
+ (type_decl_group : TypeDeclId.Set.t) (type_name : string)
+ (type_params : string list) (cg_params : string list) (cons_name : string)
+ (fields : field list) : unit =
+ F.pp_print_space fmt ();
+ (* variant box *)
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ (* [| Cons :]
+ * Note that we really don't want any break above so we print everything
+ * at once. *)
+ let opt_colon = if !backend <> HOL4 then " :" else "" in
+ F.pp_print_string fmt ("| " ^ cons_name ^ opt_colon);
+ let print_field (fid : FieldId.id) (f : field) (ctx : extraction_ctx) :
+ extraction_ctx =
+ F.pp_print_space fmt ();
+ (* Open the field box *)
+ F.pp_open_box fmt ctx.indent_incr;
+ (* Print the field names, if the backend accepts it.
+ * [ x :]
+ * Note that when printing fields, we register the field names as
+ * *variables*: they don't need to be unique at the top level. *)
+ let ctx =
+ match !backend with
+ | FStar -> (
+ match f.field_name with
+ | None -> ctx
+ | Some field_name ->
+ let var_id = VarId.of_int (FieldId.to_int fid) in
+ let field_name =
+ ctx.fmt.var_basename ctx.names_maps.names_map.names_set
+ (Some field_name) f.field_ty
+ in
+ let ctx, field_name = ctx_add_var field_name var_id ctx in
+ F.pp_print_string fmt (field_name ^ " :");
+ F.pp_print_space fmt ();
+ ctx)
+ | Coq | Lean | HOL4 -> ctx
+ in
+ (* Print the field type *)
+ let inside = !backend = HOL4 in
+ extract_ty ctx fmt type_decl_group inside f.field_ty;
+ (* Print the arrow [->] *)
+ if !backend <> HOL4 then (
+ F.pp_print_space fmt ();
+ extract_arrow fmt ());
+ (* Close the field box *)
+ F.pp_close_box fmt ();
+ (* Return *)
+ ctx
+ in
+ (* Print the fields *)
+ let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in
+ let _ =
+ List.fold_left (fun ctx (fid, f) -> print_field fid f ctx) ctx fields
+ in
+ (* Sanity check: HOL4 doesn't support const generics *)
+ assert (cg_params = [] || !backend <> HOL4);
+ (* Print the final type *)
+ if !backend <> HOL4 then (
+ F.pp_print_space fmt ();
+ F.pp_open_hovbox fmt 0;
+ F.pp_print_string fmt type_name;
+ List.iter
+ (fun p ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt p)
+ (List.append type_params cg_params);
+ F.pp_close_box fmt ());
+ (* Close the variant box *)
+ F.pp_close_box fmt ()
+
+(* TODO: we don' need the [def_name] paramter: it can be retrieved from the context *)
+let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter)
+ (type_decl_group : TypeDeclId.Set.t) (def : type_decl) (def_name : string)
+ (type_params : string list) (cg_params : string list)
+ (variants : variant list) : unit =
+ (* We want to generate a definition which looks like this (taking F* as example):
+ {[
+ type list a = | Cons : a -> list a -> list a | Nil : list a
+ ]}
+
+ If there isn't enough space on one line:
+ {[
+ type s =
+ | Cons : a -> list a -> list a
+ | Nil : list a
+ ]}
+
+ And if we need to write the type of a variant on several lines:
+ {[
+ type s =
+ | Cons :
+ a ->
+ list a ->
+ list a
+ | Nil : list a
+ ]}
+
+ Finally, it is possible to give names to the variant fields in Rust.
+ In this situation, we generate a definition like this:
+ {[
+ type s =
+ | Cons : hd:a -> tl:list a -> list a
+ | Nil : list a
+ ]}
+
+ Note that we already printed: [type s =]
+ *)
+ let print_variant _variant_id (v : variant) =
+ (* We don't lookup the name, because it may have a prefix for the type
+ id (in the case of Lean) *)
+ let cons_name = ctx.fmt.variant_name def.name v.variant_name in
+ let fields = v.fields in
+ extract_type_decl_variant ctx fmt type_decl_group def_name type_params
+ cg_params cons_name fields
+ in
+ (* Print the variants *)
+ let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in
+ List.iter (fun (vid, v) -> print_variant vid v) variants
+
+let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter)
+ (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl)
+ (type_params : string list) (cg_params : string list) (fields : field list)
+ : unit =
+ (* We want to generate a definition which looks like this (taking F* as example):
+ {[
+ type t = { x : int; y : bool; }
+ ]}
+
+ If there isn't enough space on one line:
+ {[
+ type t =
+ {
+ x : int; y : bool;
+ }
+ ]}
+
+ And if there is even less space:
+ {[
+ type t =
+ {
+ x : int;
+ y : bool;
+ }
+ ]}
+
+ Also, in case there are no fields, we need to define the type as [unit]
+ ([type t = {}] doesn't work in F* ).
+
+ Coq:
+ ====
+ We need to define the constructor name upon defining the struct (record, in Coq).
+ The syntex is:
+ {[
+ Record Foo = mkFoo { x : int; y : bool; }.
+ }]
+
+ Also, Coq doesn't support groups of mutually recursive inductives and records.
+ This is fine, because we can then define records as inductives, and leverage
+ the fact that when record fields are accessed, the records are symbolically
+ expanded which introduces let bindings of the form: [let RecordCons ... = x in ...].
+ As a consequence, we never use the record projectors (unless we reconstruct
+ them in the micro passes of course).
+
+ HOL4:
+ =====
+ Type definitions are written as follows:
+ {[
+ Datatype:
+ tree =
+ TLeaf 'a
+ | TNode node ;
+
+ node =
+ Node (tree list)
+ End
+ ]}
+ *)
+ (* Note that we already printed: [type t =] *)
+ let is_rec = decl_is_from_rec_group kind in
+ let _ =
+ if !backend = FStar && fields = [] then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (unit_name ()))
+ else if !backend = Lean && fields = [] then ()
+ (* If the definition is recursive, we may need to extract it as an inductive
+ (instead of a record). We start with the "normal" case: we extract it
+ as a record. *)
+ else if (not is_rec) || (!backend <> Coq && !backend <> Lean) then (
+ if !backend <> Lean then F.pp_print_space fmt ();
+ (* If Coq: print the constructor name *)
+ (* TODO: remove superfluous test not is_rec below *)
+ if !backend = Coq && not is_rec then (
+ F.pp_print_string fmt (ctx_get_struct (AdtId def.def_id) ctx);
+ F.pp_print_string fmt " ");
+ (match !backend with
+ | Lean -> ()
+ | FStar | Coq -> F.pp_print_string fmt "{"
+ | HOL4 -> F.pp_print_string fmt "<|");
+ F.pp_print_break fmt 1 ctx.indent_incr;
+ (* The body itself *)
+ (* Open a box for the body *)
+ (match !backend with
+ | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0
+ | Lean -> F.pp_open_vbox fmt 0);
+ (* Print the fields *)
+ let print_field (field_id : FieldId.id) (f : field) : unit =
+ let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in
+ (* Open a box for the field *)
+ F.pp_open_box fmt ctx.indent_incr;
+ F.pp_print_string fmt field_name;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt type_decl_group false f.field_ty;
+ if !backend <> Lean then F.pp_print_string fmt ";";
+ (* Close the box for the field *)
+ F.pp_close_box fmt ()
+ in
+ let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (fun (fid, f) -> print_field fid f)
+ fields;
+ (* Close the box for the body *)
+ F.pp_close_box fmt ();
+ match !backend with
+ | Lean -> ()
+ | FStar | Coq ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}"
+ | HOL4 ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "|>")
+ else (
+ (* We extract for Coq or Lean, and we have a recursive record, or a record in
+ a group of mutually recursive types: we extract it as an inductive type *)
+ assert (is_rec && (!backend = Coq || !backend = Lean));
+ (* Small trick: in Lean we use namespaces, meaning we don't need to prefix
+ the constructor name with the name of the type at definition site,
+ i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq
+ we generate `inductive Foo := | mk ... *)
+ let cons_name =
+ if !backend = Lean then "mk" else ctx_get_struct (AdtId def.def_id) ctx
+ in
+ let def_name = ctx_get_local_type def.def_id ctx in
+ extract_type_decl_variant ctx fmt type_decl_group def_name type_params
+ cg_params cons_name fields)
+ in
+ ()
+
+(** Extract a nestable, muti-line comment *)
+let extract_comment (fmt : F.formatter) (sl : string list) : unit =
+ (* Delimiters, space after we break a line *)
+ let ld, space, rd =
+ match !backend with
+ | Coq | FStar | HOL4 -> ("(** ", 4, " *)")
+ | Lean -> ("/- ", 3, " -/")
+ in
+ F.pp_open_vbox fmt space;
+ F.pp_print_string fmt ld;
+ (match sl with
+ | [] -> ()
+ | s :: sl ->
+ F.pp_print_string fmt s;
+ List.iter
+ (fun s ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt s)
+ sl);
+ F.pp_print_string fmt rd;
+ F.pp_close_box fmt ()
+
+let extract_trait_clause_type (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (clause : trait_clause) : unit =
+ let trait_name = ctx_get_trait_decl clause.trait_id ctx in
+ F.pp_print_string fmt trait_name;
+ extract_generic_args ctx fmt no_params_tys clause.generics
+
+(** Insert a space, if necessary *)
+let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
+ if !space then space := false else F.pp_print_space fmt ()
+
+(** Extract the trait self clause.
+
+ We add the trait self clause for provided methods (see {!TraitSelfClauseId}).
+ *)
+let extract_trait_self_clause (insert_req_space : unit -> unit)
+ (ctx : extraction_ctx) (fmt : F.formatter) (trait_decl : trait_decl)
+ (params : string list) : unit =
+ insert_req_space ();
+ F.pp_print_string fmt "(";
+ let self_clause = ctx_get_trait_self_clause ctx in
+ F.pp_print_string fmt self_clause;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ let trait_id = ctx_get_trait_decl trait_decl.def_id ctx in
+ F.pp_print_string fmt trait_id;
+ List.iter
+ (fun p ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt p)
+ params;
+ F.pp_print_string fmt ")"
+
+(**
+ - [trait_decl]: if [Some], it means we are extracting the generics for a provided
+ method and need to insert a trait self clause (see {!TraitSelfClauseId}).
+ *)
+let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) ?(use_forall = false)
+ ?(use_forall_use_sep = true) ?(use_arrows = false)
+ ?(as_implicits : bool = false) ?(space : bool ref option = None)
+ ?(trait_decl : trait_decl option = None) (generics : generic_params)
+ (type_params : string list) (cg_params : string list)
+ (trait_clauses : string list) : unit =
+ let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
+ (* HOL4 doesn't support const generics *)
+ assert (cg_params = [] || !backend <> HOL4);
+ let left_bracket (implicit : bool) =
+ if implicit && !backend <> FStar then F.pp_print_string fmt "{"
+ else F.pp_print_string fmt "("
+ in
+ let right_bracket (implicit : bool) =
+ if implicit && !backend <> FStar then F.pp_print_string fmt "}"
+ else F.pp_print_string fmt ")"
+ in
+ let print_implicit_symbol (implicit : bool) =
+ if implicit && !backend = FStar then F.pp_print_string fmt "#" else ()
+ in
+ let insert_req_space () =
+ match space with
+ | None -> F.pp_print_space fmt ()
+ | Some space -> insert_req_space fmt space
+ in
+ (* Print the type/const generic parameters *)
+ if all_params <> [] then (
+ if use_forall then (
+ if use_forall_use_sep then (
+ insert_req_space ();
+ F.pp_print_string fmt ":");
+ insert_req_space ();
+ F.pp_print_string fmt "forall");
+ (* Small helper - we may need to split the parameters *)
+ let print_generics (as_implicits : bool) (type_params : string list)
+ (const_generics : const_generic_var list)
+ (trait_clauses : trait_clause list) : unit =
+ (* Note that in HOL4 we don't print the type parameters. *)
+ if !backend <> HOL4 then (
+ (* Print the type parameters *)
+ if type_params <> [] then (
+ insert_req_space ();
+ (* ( *)
+ left_bracket as_implicits;
+ List.iter
+ (fun s ->
+ print_implicit_symbol as_implicits;
+ F.pp_print_string fmt s;
+ F.pp_print_space fmt ())
+ type_params;
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (type_keyword ());
+ (* ) *)
+ right_bracket as_implicits;
+ if use_arrows then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "->"));
+ (* Print the const generic parameters *)
+ List.iter
+ (fun (var : const_generic_var) ->
+ insert_req_space ();
+ (* ( *)
+ left_bracket as_implicits;
+ let n = ctx_get_const_generic_var var.index ctx in
+ print_implicit_symbol as_implicits;
+ F.pp_print_string fmt n;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_literal_type ctx fmt var.ty;
+ (* ) *)
+ right_bracket as_implicits;
+ if use_arrows then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "->"))
+ const_generics);
+ (* Print the trait clauses *)
+ List.iter
+ (fun (clause : trait_clause) ->
+ insert_req_space ();
+ (* ( *)
+ left_bracket as_implicits;
+ let n = ctx_get_local_trait_clause clause.clause_id ctx in
+ print_implicit_symbol as_implicits;
+ F.pp_print_string fmt n;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_trait_clause_type ctx fmt no_params_tys clause;
+ (* ) *)
+ right_bracket as_implicits;
+ if use_arrows then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "->"))
+ trait_clauses
+ in
+ (* If we extract the generics for a provided method for a trait declaration
+ (indicated by the trait decl given as input), we need to split the generics:
+ - we print the generics for the trait decl
+ - we print the trait self clause
+ - we print the generics for the trait method
+ *)
+ match trait_decl with
+ | None ->
+ print_generics as_implicits type_params generics.const_generics
+ generics.trait_clauses
+ | Some trait_decl ->
+ (* Split the generics between the generics specific to the trait decl
+ and those specific to the trait method *)
+ let open Collections.List in
+ let dtype_params, mtype_params =
+ split_at type_params (length trait_decl.generics.types)
+ in
+ let dcgs, mcgs =
+ split_at generics.const_generics
+ (length trait_decl.generics.const_generics)
+ in
+ let dtrait_clauses, mtrait_clauses =
+ split_at generics.trait_clauses
+ (length trait_decl.generics.trait_clauses)
+ in
+ (* Extract the trait decl generics - note that we can always deduce
+ those parameters from the trait self clause: for this reason
+ they are always implicit *)
+ print_generics true dtype_params dcgs dtrait_clauses;
+ (* Extract the trait self clause *)
+ let params =
+ concat
+ [
+ dtype_params;
+ map
+ (fun (cg : const_generic_var) ->
+ ctx_get_const_generic_var cg.index ctx)
+ dcgs;
+ map
+ (fun c -> ctx_get_local_trait_clause c.clause_id ctx)
+ dtrait_clauses;
+ ]
+ in
+ extract_trait_self_clause insert_req_space ctx fmt trait_decl params;
+ (* Extract the method generics *)
+ print_generics as_implicits mtype_params mcgs mtrait_clauses)
+
+(** Extract a type declaration.
+
+ This function is for all type declarations and all backends **at the exception**
+ of opaque (assumed/declared) types format4 HOL4.
+
+ See {!extract_type_decl}.
+ *)
+let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
+ (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl)
+ (extract_body : bool) : unit =
+ (* Sanity check *)
+ assert (extract_body || !backend <> HOL4);
+ let type_kind =
+ if extract_body then
+ match def.kind with
+ | Struct _ -> Some Struct
+ | Enum _ -> Some Enum
+ | Opaque -> None
+ else None
+ in
+ (* If in Coq and the declaration is opaque, it must have the shape:
+ [Axiom Ident : forall (T0 ... Tn : Type) (N0 : ...) ... (Nn : ...), ... -> ... -> ...].
+
+ The boolean [is_opaque_coq] is used to detect this case.
+ *)
+ let is_opaque = type_kind = None in
+ let is_opaque_coq = !backend = Coq && is_opaque in
+ let use_forall = is_opaque_coq && def.generics <> empty_generic_params in
+ (* Retrieve the definition name *)
+ let def_name = ctx_get_local_type def.def_id ctx in
+ (* Add the type and const generic params - note that we need those bindings only for the
+ * body translation (they are not top-level) *)
+ let ctx_body, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params def.generics ctx
+ in
+ (* Add a break before *)
+ if !backend <> HOL4 || not (decl_is_first_from_group kind) then
+ F.pp_print_break fmt 0 0;
+ (* Print a comment to link the extracted type to its original rust definition *)
+ extract_comment fmt [ "[" ^ Print.name_to_string def.name ^ "]" ];
+ F.pp_print_break fmt 0 0;
+ (* Open a box for the definition, so that whenever possible it gets printed on
+ * one line. Note however that in the case of Lean line breaks are important
+ * for parsing: we thus use a hovbox. *)
+ (match !backend with
+ | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0
+ | Lean -> F.pp_open_vbox fmt 0);
+ (* Open a box for "type TYPE_NAME (TYPE_PARAMS CONST_GEN_PARAMS) =" *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ (* > "type TYPE_NAME" *)
+ let qualif = ctx.fmt.type_decl_kind_to_qualif kind type_kind in
+ (match qualif with
+ | Some qualif -> F.pp_print_string fmt (qualif ^ " " ^ def_name)
+ | None -> F.pp_print_string fmt def_name);
+ (* HOL4 doesn't support const generics, and type definitions in HOL4 don't
+ support trait clauses *)
+ assert ((cg_params = [] && trait_clauses = []) || !backend <> HOL4);
+ (* Print the generic parameters *)
+ extract_generic_params ctx_body fmt type_decl_group ~use_forall def.generics
+ type_params cg_params trait_clauses;
+ (* Print the "=" if we extract the body*)
+ if extract_body then (
+ F.pp_print_space fmt ();
+ let eq =
+ match !backend with
+ | FStar -> "="
+ | Coq -> ":="
+ | Lean ->
+ if type_kind = Some Struct && kind = SingleNonRec then "where"
+ else ":="
+ | HOL4 -> "="
+ in
+ F.pp_print_string fmt eq)
+ else (
+ (* Otherwise print ": Type", unless it is the HOL4 backend (in
+ which case we declare the type with `new_type`) *)
+ if use_forall then F.pp_print_string fmt ","
+ else (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":");
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (type_keyword ()));
+ (* Close the box for "type TYPE_NAME (TYPE_PARAMS) =" *)
+ F.pp_close_box fmt ();
+ (if extract_body then
+ match def.kind with
+ | Struct fields ->
+ extract_type_decl_struct_body ctx_body fmt type_decl_group kind def
+ type_params cg_params fields
+ | Enum variants ->
+ extract_type_decl_enum_body ctx_body fmt type_decl_group def def_name
+ type_params cg_params variants
+ | Opaque -> raise (Failure "Unreachable"));
+ (* Add the definition end delimiter *)
+ if !backend = HOL4 && decl_is_not_last_from_group kind then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ";")
+ else if !backend = Coq && decl_is_last_from_group kind then (
+ (* This is actually an end of group delimiter. For aesthetic reasons
+ we print it here instead of in {!end_type_decl_group}. *)
+ F.pp_print_cut fmt ();
+ F.pp_print_string fmt ".");
+ (* Close the box for the definition *)
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ if !backend <> HOL4 || decl_is_not_last_from_group kind then
+ F.pp_print_break fmt 0 0
+
+(** Extract an opaque type declaration to HOL4.
+
+ Remark (SH): having to treat this specific case separately is very annoying,
+ but I could not find a better way.
+ *)
+let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
+ (def : type_decl) : unit =
+ (* Retrieve the definition name *)
+ let def_name = ctx_get_local_type def.def_id ctx in
+ (* Generic parameters are unsupported *)
+ assert (def.generics.const_generics = []);
+ (* Trait clauses on type definitions are unsupported *)
+ assert (def.generics.trait_clauses = []);
+ (* Types *)
+ (* Count the number of parameters *)
+ let num_params = List.length def.generics.types in
+ (* Generate the declaration *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt
+ ("val _ = new_type (\"" ^ def_name ^ "\", " ^ string_of_int num_params ^ ")");
+ F.pp_print_space fmt ()
+
+(** Extract an empty record type declaration to HOL4.
+
+ Empty records are not supported in HOL4, so we extract them as type
+ abbreviations to the unit type.
+
+ Remark (SH): having to treat this specific case separately is very annoying,
+ but I could not find a better way.
+ *)
+let extract_type_decl_hol4_empty_record (ctx : extraction_ctx)
+ (fmt : F.formatter) (def : type_decl) : unit =
+ (* Retrieve the definition name *)
+ let def_name = ctx_get_local_type def.def_id ctx in
+ (* Sanity check *)
+ assert (def.generics = empty_generic_params);
+ (* Generate the declaration *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ("Type " ^ def_name ^ " = “: unit”");
+ F.pp_print_space fmt ()
+
+(** Extract a type declaration.
+
+ Note that all the names used for extraction should already have been
+ registered.
+
+ This function should be inserted between calls to {!start_type_decl_group}
+ and {!end_type_decl_group}.
+ *)
+let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter)
+ (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) :
+ unit =
+ let extract_body =
+ match kind with
+ | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> true
+ | Assumed | Declared -> false
+ in
+ if extract_body then
+ if !backend = HOL4 && is_empty_record_type_decl def then
+ extract_type_decl_hol4_empty_record ctx fmt def
+ else extract_type_decl_gen ctx fmt type_decl_group kind def extract_body
+ else
+ match !backend with
+ | FStar | Coq | Lean ->
+ extract_type_decl_gen ctx fmt type_decl_group kind def extract_body
+ | HOL4 -> extract_type_decl_hol4_opaque ctx fmt def
+
+(** Generate a [Argument] instruction in Coq to allow omitting implicit
+ arguments for variants, fields, etc..
+
+ For instance, provided we have this definition:
+ {[
+ Inductive result A :=
+ | Return : A -> result A
+ | Fail_ : error -> result A.
+ ]}
+
+ We may want to generate those instructions:
+ {[
+ Arguments Return {_} a.
+ Arguments Fail_ {_}.
+ ]}
+ *)
+let extract_coq_arguments_instruction (ctx : extraction_ctx) (fmt : F.formatter)
+ (cons_name : string) (num_implicit_params : int) : unit =
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Open a box *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_print_break fmt 0 0;
+ F.pp_print_string fmt "Arguments";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt cons_name;
+ (* Print the type/const params and the trait clauses (`{T}`) *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "{";
+ Collections.List.iter_times num_implicit_params (fun () ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "_");
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}.";
+
+ (* Close the box *)
+ F.pp_close_box fmt ()
+
+(** Auxiliary function.
+
+ Generate [Arguments] instructions in Coq for type definitions.
+ *)
+let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
+ (kind : decl_kind) (decl : type_decl) : unit =
+ assert (!backend = Coq);
+ (* Generating the [Arguments] instructions is useful only if there are parameters *)
+ let num_params =
+ List.length decl.generics.types
+ + List.length decl.generics.const_generics
+ + List.length decl.generics.trait_clauses
+ in
+ if num_params = 0 then ()
+ else
+ (* Generate the [Arguments] instruction *)
+ match decl.kind with
+ | Opaque -> ()
+ | Struct fields ->
+ let adt_id = AdtId decl.def_id in
+ (* Generate the instruction for the record constructor *)
+ let cons_name = ctx_get_struct adt_id ctx in
+ extract_coq_arguments_instruction ctx fmt cons_name num_params;
+ (* Generate the instruction for the record projectors, if there are *)
+ let is_rec = decl_is_from_rec_group kind in
+ if not is_rec then
+ FieldId.iteri
+ (fun fid _ ->
+ let cons_name = ctx_get_field adt_id fid ctx in
+ extract_coq_arguments_instruction ctx fmt cons_name num_params)
+ fields;
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+ | Enum variants ->
+ (* Generate the instructions *)
+ VariantId.iteri
+ (fun vid (_ : variant) ->
+ let cons_name = ctx_get_variant (AdtId decl.def_id) vid ctx in
+ extract_coq_arguments_instruction ctx fmt cons_name num_params)
+ variants;
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+
+(** Auxiliary function.
+
+ Generate field projectors in Coq.
+
+ Sometimes we extract records as inductives in Coq: when this happens we
+ have to define the field projectors afterwards.
+ *)
+let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
+ (fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit =
+ assert (!backend = Coq);
+ match decl.kind with
+ | Opaque | Enum _ -> ()
+ | Struct fields ->
+ (* Records are extracted as inductives only if they are recursive *)
+ let is_rec = decl_is_from_rec_group kind in
+ if is_rec then
+ (* Add the type params *)
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params decl.generics ctx
+ in
+ let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in
+ let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in
+ let def_name = ctx_get_local_type decl.def_id ctx in
+ let cons_name = ctx_get_struct (AdtId decl.def_id) ctx in
+ let extract_field_proj (field_id : FieldId.id) (_ : field) : unit =
+ F.pp_print_space fmt ();
+ (* Outer box for the projector definition *)
+ F.pp_open_hvbox fmt 0;
+ (* Inner box for the projector definition *)
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ (* Open a box for the [Definition PROJ ... :=] *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_print_string fmt "Definition";
+ F.pp_print_space fmt ();
+ let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in
+ F.pp_print_string fmt field_name;
+ (* Print the generics *)
+ let as_implicits = true in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~as_implicits
+ decl.generics type_params cg_params trait_clauses;
+ (* Print the record parameter *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "(";
+ F.pp_print_string fmt record_var;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt def_name;
+ List.iter
+ (fun p ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt p)
+ type_params;
+ F.pp_print_string fmt ")";
+ (* *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":=";
+ (* Close the box for the [Definition PROJ ... :=] *)
+ F.pp_close_box fmt ();
+ F.pp_print_space fmt ();
+ (* Open a box for the whole match *)
+ F.pp_open_hvbox fmt 0;
+ (* Open a box for the [match ... with] *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_print_string fmt "match";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt record_var;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "with";
+ (* Close the box for the [match ... with] *)
+ F.pp_close_box fmt ();
+
+ (* Open a box for the branch *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ (* Print the match branch *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "|";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt cons_name;
+ FieldId.iteri
+ (fun id _ ->
+ F.pp_print_space fmt ();
+ if field_id = id then F.pp_print_string fmt field_var
+ else F.pp_print_string fmt "_")
+ fields;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "=>";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt field_var;
+ (* Close the box for the branch *)
+ F.pp_close_box fmt ();
+ (* Print the [end] *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "end";
+ (* Close the box for the whole match *)
+ F.pp_close_box fmt ();
+ (* Close the inner box projector *)
+ F.pp_close_box fmt ();
+ (* If Coq: end the definition with a "." *)
+ if !backend = Coq then (
+ F.pp_print_cut fmt ();
+ F.pp_print_string fmt ".");
+ (* Close the outer box projector *)
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+ in
+
+ let extract_proj_notation (field_id : FieldId.id) (_ : field) : unit =
+ F.pp_print_space fmt ();
+ (* Outer box for the projector definition *)
+ F.pp_open_hvbox fmt 0;
+ (* Inner box for the projector definition *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in
+ F.pp_print_string fmt "Notation";
+ F.pp_print_space fmt ();
+ let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in
+ F.pp_print_string fmt ("\"" ^ record_var ^ " .(" ^ field_name ^ ")\"");
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":=";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "(";
+ F.pp_print_string fmt field_name;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt record_var;
+ F.pp_print_string fmt ")";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "(at level 9)";
+ (* Close the inner box projector *)
+ F.pp_close_box fmt ();
+ (* If Coq: end the definition with a "." *)
+ if !backend = Coq then (
+ F.pp_print_cut fmt ();
+ F.pp_print_string fmt ".");
+ (* Close the outer box projector *)
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+ in
+
+ let extract_field_proj_and_notation (field_id : FieldId.id)
+ (field : field) : unit =
+ extract_field_proj field_id field;
+ extract_proj_notation field_id field
+ in
+
+ FieldId.iteri extract_field_proj_and_notation fields
+
+(** Extract extra information for a type (e.g., [Arguments] instructions in Coq).
+
+ Note that all the names used for extraction should already have been
+ registered.
+ *)
+let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter)
+ (kind : decl_kind) (decl : type_decl) : unit =
+ match !backend with
+ | FStar | Lean | HOL4 -> ()
+ | Coq ->
+ extract_type_decl_coq_arguments ctx fmt kind decl;
+ extract_type_decl_record_field_projectors ctx fmt kind decl
+
+(** Extract the state type declaration. *)
+let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx)
+ (kind : decl_kind) : unit =
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment *)
+ extract_comment fmt [ "The state type used in the state-error monad" ];
+ F.pp_print_break fmt 0 0;
+ (* Open a box for the definition, so that whenever possible it gets printed on
+ * one line *)
+ F.pp_open_hvbox fmt 0;
+ (* Retrieve the name *)
+ let state_name = ctx_get_assumed_type State ctx in
+ (* The syntax for Lean and Coq is almost identical. *)
+ let print_axiom () =
+ let axiom =
+ match !backend with
+ | Coq -> "Axiom"
+ | Lean -> "axiom"
+ | FStar | HOL4 -> raise (Failure "Unexpected")
+ in
+ F.pp_print_string fmt axiom;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt state_name;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "Type";
+ if !backend = Coq then F.pp_print_string fmt "."
+ in
+ (* The kind should be [Assumed] or [Declared] *)
+ (match kind with
+ | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast ->
+ raise (Failure "Unexpected")
+ | Assumed -> (
+ match !backend with
+ | FStar ->
+ F.pp_print_string fmt "assume";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "type";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt state_name;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "Type0"
+ | HOL4 ->
+ F.pp_print_string fmt ("val _ = new_type (\"" ^ state_name ^ "\", 0)")
+ | Coq | Lean -> print_axiom ())
+ | Declared -> (
+ match !backend with
+ | FStar ->
+ F.pp_print_string fmt "val";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt state_name;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "Type0"
+ | HOL4 ->
+ F.pp_print_string fmt ("val _ = new_type (\"" ^ state_name ^ "\", 0)")
+ | Coq | Lean -> print_axiom ()));
+ (* Close the box for the definition *)
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml
index b72fa078..e17ea16f 100644
--- a/compiler/FunsAnalysis.ml
+++ b/compiler/FunsAnalysis.ml
@@ -57,12 +57,26 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
let stateful = ref false in
let can_diverge = ref false in
let is_rec = ref false in
+ let group_has_builtin_info = ref false in
+
+ (* We have some specialized knowledge of some library functions; we don't
+ have any more custom treatment than this, and these functions can be modeled
+ suitably in Primitives.fst, rather than special-casing for them all the
+ way. *)
+ let get_builtin_info (f : fun_decl) : ExtractBuiltin.effect_info option =
+ let open ExtractBuiltin in
+ let name = name_to_simple_name f.name in
+ SimpleNameMap.find_opt name builtin_fun_effects_map
+ in
+ (* JP: Why not use a reduce visitor here with a tuple of the values to be
+ computed? *)
let visit_fun (f : fun_decl) : unit =
let obj =
object (self)
inherit [_] iter_statement as super
method may_fail b = can_fail := !can_fail || b
+ method maybe_stateful b = stateful := !stateful || b
method! visit_Assert env a =
self#may_fail true;
@@ -70,14 +84,14 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
method! visit_rvalue _env rv =
match rv with
- | Use _ | Ref _ | Global _ | Discriminant _ | Aggregate _ -> ()
+ | Use _ | RvRef _ | Global _ | Discriminant _ | Aggregate _ -> ()
| UnaryOp (uop, _) -> can_fail := EU.unop_can_fail uop || !can_fail
| BinaryOp (bop, _, _) ->
can_fail := EU.binop_can_fail bop || !can_fail
method! visit_Call env call =
- (match call.func with
- | Regular id ->
+ (match call.func.func with
+ | FunId (Regular id) ->
if FunDeclId.Set.mem id fun_ids then (
can_diverge := true;
is_rec := true)
@@ -86,9 +100,14 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
self#may_fail info.can_fail;
stateful := !stateful || info.stateful;
can_diverge := !can_diverge || info.can_diverge
- | Assumed id ->
+ | FunId (Assumed id) ->
(* None of the assumed functions can diverge nor are considered stateful *)
- can_fail := !can_fail || Assumed.assumed_can_fail id);
+ can_fail := !can_fail || Assumed.assumed_fun_can_fail id
+ | TraitMethod _ ->
+ (* We consider trait functions can fail, but can not diverge and are not stateful.
+ TODO: this may cause issues if we use use a fuel parameter.
+ *)
+ can_fail := true);
super#visit_Call env call
method! visit_Panic env =
@@ -102,11 +121,21 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
in
(* Sanity check: global bodies don't contain stateful calls *)
assert ((not f.is_global_decl_body) || not !stateful);
+ let builtin_info = get_builtin_info f in
+ let has_builtin_info = builtin_info <> None in
+ group_has_builtin_info := !group_has_builtin_info || has_builtin_info;
match f.body with
| None ->
- (* Opaque function: we consider they fail by default *)
- obj#may_fail true;
- stateful := (not f.is_global_decl_body) && use_state
+ let info_can_fail, info_stateful =
+ match builtin_info with
+ | None -> (true, use_state)
+ | Some { can_fail; stateful } -> (can_fail, stateful)
+ in
+ obj#may_fail info_can_fail;
+ obj#maybe_stateful
+ (if f.is_global_decl_body then false
+ else if not use_state then false
+ else info_stateful)
| Some body -> obj#visit_statement () body.body
in
List.iter visit_fun d;
@@ -114,12 +143,17 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
* groups containing globals contain exactly one declaration *)
let is_global_decl_body = List.exists (fun f -> f.is_global_decl_body) d in
assert ((not is_global_decl_body) || List.length d = 1);
+ assert ((not !group_has_builtin_info) || List.length d = 1);
(* We ignore on purpose functions that cannot fail and consider they *can*
* fail: the result of the analysis is not used yet to adjust the translation
* so that the functions which syntactically can't fail don't use an error monad.
- * However, we do keep the result of the analysis for global bodies.
+ * However, we do keep the result of the analysis for global bodies and for
+ * builtin functions which are marked as non-fallible.
* *)
- can_fail := (not is_global_decl_body) || !can_fail;
+ can_fail :=
+ if is_global_decl_body then !can_fail
+ else if !group_has_builtin_info then !can_fail
+ else true;
{
can_fail = !can_fail;
stateful = !stateful;
@@ -141,7 +175,8 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
let rec analyze_decl_groups (decls : declaration_group list) : unit =
match decls with
| [] -> ()
- | Type _ :: decls' -> analyze_decl_groups decls'
+ | (Type _ | TraitDecl _ | TraitImpl _) :: decls' ->
+ analyze_decl_groups decls'
| Fun decl :: decls' ->
analyze_fun_decl_group decl;
analyze_decl_groups decls'
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml
index 154c5a21..24ff4808 100644
--- a/compiler/Interpreter.ml
+++ b/compiler/Interpreter.ml
@@ -12,55 +12,165 @@ module SA = SymbolicAst
(** The local logger *)
let log = L.interpreter_log
-let compute_type_fun_global_contexts (m : A.crate) :
- C.type_context * C.fun_context * C.global_context =
- let type_decls_list, _, _ = split_declarations m.declarations in
+let compute_contexts (m : A.crate) : C.decls_ctx =
+ let type_decls_list, _, _, _, _ = split_declarations m.declarations in
let type_decls = m.types in
let fun_decls = m.functions in
let global_decls = m.globals in
- let type_decls_groups, _funs_defs_groups, _globals_defs_groups =
+ let trait_decls = m.trait_decls in
+ let trait_impls = m.trait_impls in
+ let type_decls_groups, _, _, _, _ =
split_declarations_to_group_maps m.declarations
in
let type_infos =
TypesAnalysis.analyze_type_declarations type_decls type_decls_list
in
- let type_context = { C.type_decls_groups; type_decls; type_infos } in
- let fun_context = { C.fun_decls } in
- let global_context = { C.global_decls } in
- (type_context, fun_context, global_context)
-
-let initialize_eval_context (type_context : C.type_context)
- (fun_context : C.fun_context) (global_context : C.global_context)
- (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list)
- (const_generic_vars : T.const_generic_var list) : C.eval_ctx =
- C.reset_global_counters ();
- {
- C.type_context;
- C.fun_context;
- C.global_context;
- C.region_groups;
- C.type_vars;
- C.const_generic_vars;
- C.env = [ C.Frame ];
- C.ended_regions = T.RegionId.Set.empty;
- }
+ let type_ctx = { C.type_decls_groups; type_decls; type_infos } in
+ let fun_infos =
+ FunsAnalysis.analyze_module m fun_decls global_decls !Config.use_state
+ in
+ let fun_ctx = { C.fun_decls; fun_infos } in
+ let global_ctx = { C.global_decls } in
+ let trait_decls_ctx = { C.trait_decls } in
+ let trait_impls_ctx = { C.trait_impls } in
+ { C.type_ctx; fun_ctx; global_ctx; trait_decls_ctx; trait_impls_ctx }
+
+(** Small helper.
+
+ Normalize an instantiated function signature provided we used this signature
+ to compute a normalization map (for the associated types) and that we added
+ it in the context.
+ *)
+let normalize_inst_fun_sig (ctx : C.eval_ctx) (sg : A.inst_fun_sig) :
+ A.inst_fun_sig =
+ let { A.regions_hierarchy = _; trait_type_constraints = _; inputs; output } =
+ sg
+ in
+ let norm = AssociatedTypes.ctx_normalize_rty ctx in
+ let inputs = List.map norm inputs in
+ let output = norm output in
+ { sg with A.inputs; output }
+
+(** Instantiate a function signature for a symbolic execution.
+
+ We return a new context because we compute and add the type normalization
+ map in the same step.
+
+ **WARNING**: this doesn't normalize the types. This step has to be done
+ separately. Remark: we need to normalize essentially because of the where
+ clauses (we are not considering a function call, so we don't need to
+ normalize because a trait clause was instantiated with a specific trait ref).
+ *)
+let symbolic_instantiate_fun_sig (ctx : C.eval_ctx) (sg : A.fun_sig)
+ (kind : A.fun_kind) : C.eval_ctx * A.inst_fun_sig =
+ let tr_self =
+ match kind with
+ | RegularKind | TraitMethodImpl _ -> T.UnknownTrait __FUNCTION__
+ | TraitMethodDecl _ | TraitMethodProvided _ -> T.Self
+ in
+ let generics =
+ let { T.regions; types; const_generics; trait_clauses } = sg.generics in
+ let regions = List.map (fun _ -> T.Erased) regions in
+ let types = List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) types in
+ let const_generics =
+ List.map
+ (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
+ const_generics
+ in
+ (* Annoying that we have to generate this substitution here *)
+ let r_subst _ = raise (Failure "Unexpected region") in
+ let ty_subst = Subst.make_type_subst_from_vars sg.generics.types types in
+ let cg_subst =
+ Subst.make_const_generic_subst_from_vars sg.generics.const_generics
+ const_generics
+ in
+ (* TODO: some clauses may use the types of other clauses, so we may have to
+ reorder them.
+
+ Example:
+ If in Rust we write:
+ {[
+ pub fn use_get<'a, T: Get>(x: &'a mut T) -> u32
+ where
+ T::Item: ToU32,
+ {
+ x.get().to_u32()
+ }
+ ]}
+
+ In LLBC we get:
+ {[
+ fn demo::use_get<'a, T>(@1: &'a mut (T)) -> u32
+ where
+ [@TraitClause0]: demo::Get<T>,
+ [@TraitClause1]: demo::ToU32<@TraitClause0::Item>, // HERE
+ {
+ ... // Omitted
+ }
+ ]}
+ *)
+ (* We will need to update the trait refs map while we perform the instantiations *)
+ let mk_tr_subst
+ (tr_map : T.erased_region T.trait_instance_id T.TraitClauseId.Map.t)
+ clause_id : T.erased_region T.trait_instance_id =
+ match T.TraitClauseId.Map.find_opt clause_id tr_map with
+ | Some tr -> tr
+ | None -> raise (Failure "Local trait clause not found")
+ in
+ let mk_subst tr_map =
+ let tr_subst = mk_tr_subst tr_map in
+ { Subst.r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+ in
+ let _, trait_refs =
+ List.fold_left_map
+ (fun tr_map (c : T.trait_clause) ->
+ let subst = mk_subst tr_map in
+ let { T.trait_id = trait_decl_id; generics; _ } = c in
+ let generics = Subst.generic_args_substitute subst generics in
+ let trait_decl_ref = { T.trait_decl_id; decl_generics = generics } in
+ (* Note that because we directly refer to the clause, we give it
+ empty generics *)
+ let trait_id = T.Clause c.clause_id in
+ let trait_ref =
+ {
+ T.trait_id;
+ generics = TypesUtils.mk_empty_generic_args;
+ trait_decl_ref;
+ }
+ in
+ (* Update the traits map *)
+ let tr_map = T.TraitClauseId.Map.add c.T.clause_id trait_id tr_map in
+ (tr_map, trait_ref))
+ T.TraitClauseId.Map.empty trait_clauses
+ in
+ { T.regions; types; const_generics; trait_refs }
+ in
+ let inst_sg = instantiate_fun_sig ctx generics tr_self sg in
+ (* Compute the normalization maps *)
+ let ctx =
+ AssociatedTypes.ctx_add_norm_trait_types_from_preds ctx
+ inst_sg.trait_type_constraints
+ in
+ (* Normalize the signature *)
+ let inst_sg = normalize_inst_fun_sig ctx inst_sg in
+ (* Return *)
+ (ctx, inst_sg)
(** Initialize an evaluation context to execute a function.
- Introduces local variables initialized in the following manner:
- - input arguments are initialized as symbolic values
- - the remaining locals are initialized as [⊥]
- Abstractions are introduced for the regions present in the function
- signature.
-
- We return:
- - the initialized evaluation context
- - the list of symbolic values introduced for the input values
- - the instantiated function signature
+ Introduces local variables initialized in the following manner:
+ - input arguments are initialized as symbolic values
+ - the remaining locals are initialized as [⊥]
+ Abstractions are introduced for the regions present in the function
+ signature.
+
+ We return:
+ - the initialized evaluation context
+ - the list of symbolic values introduced for the input values
+ - the instantiated function signature
*)
-let initialize_symbolic_context_for_fun (type_context : C.type_context)
- (fun_context : C.fun_context) (global_context : C.global_context)
- (fdef : A.fun_decl) : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig =
+let initialize_symbolic_context_for_fun (ctx : C.decls_ctx) (fdef : A.fun_decl)
+ : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig =
(* The abstractions are not initialized the same way as for function
* calls: they contain *loan* projectors, because they "provide" us
* with the input values (which behave as if they had been returned
@@ -78,19 +188,15 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context)
List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy
in
let ctx =
- initialize_eval_context type_context fun_context global_context
- region_groups sg.type_params sg.const_generic_params
+ initialize_eval_context ctx region_groups sg.generics.types
+ sg.generics.const_generics
in
- (* Instantiate the signature *)
- let type_params =
- List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params
+ (* Instantiate the signature. This updates the context because we compute
+ at the same time the normalization map for the associated types.
+ *)
+ let ctx, inst_sg =
+ symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind
in
- let cg_params =
- List.map
- (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
- sg.const_generic_params
- in
- let inst_sg = instantiate_fun_sig type_params cg_params sg in
(* Create fresh symbolic values for the inputs *)
let input_svs =
List.map (fun ty -> mk_fresh_symbolic_value V.SynthInput ty) inst_sg.inputs
@@ -165,15 +271,9 @@ let evaluate_function_symbolic_synthesize_backward_from_return
* an instantiation of the signature, so that we use fresh
* region ids for the return abstractions. *)
let sg = fdef.signature in
- let type_params =
- List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params
- in
- let cg_params =
- List.map
- (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
- sg.const_generic_params
+ let _, ret_inst_sg =
+ symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind
in
- let ret_inst_sg = instantiate_fun_sig type_params cg_params sg in
let ret_rty = ret_inst_sg.output in
(* Move the return value out of the return variable *)
let pop_return_value = is_regular_return in
@@ -347,19 +447,14 @@ let evaluate_function_symbolic_synthesize_backward_from_return
for the synthesis)
- the symbolic AST generated by the symbolic execution
*)
-let evaluate_function_symbolic (synthesize : bool)
- (type_context : C.type_context) (fun_context : C.fun_context)
- (global_context : C.global_context) (fdef : A.fun_decl) :
- V.symbolic_value list * SA.expression option =
+let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx)
+ (fdef : A.fun_decl) : V.symbolic_value list * SA.expression option =
(* Debug *)
let name_to_string () = Print.fun_name_to_string fdef.A.name in
log#ldebug (lazy ("evaluate_function_symbolic: " ^ name_to_string ()));
(* Create the evaluation context *)
- let ctx, input_svs, inst_sg =
- initialize_symbolic_context_for_fun type_context fun_context global_context
- fdef
- in
+ let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in
(* Create the continuation to finish the evaluation *)
let config = C.mk_config C.SymbolicMode in
@@ -488,7 +583,8 @@ module Test = struct
(** Test a unit function (taking no arguments) by evaluating it in an empty
environment.
*)
- let test_unit_function (crate : A.crate) (fid : A.FunDeclId.id) : unit =
+ let test_unit_function (crate : A.crate) (decls_ctx : C.decls_ctx)
+ (fid : A.FunDeclId.id) : unit =
(* Retrieve the function declaration *)
let fdef = A.FunDeclId.Map.find fid crate.functions in
let body = Option.get fdef.body in
@@ -498,17 +594,11 @@ module Test = struct
(lazy ("test_unit_function: " ^ Print.fun_name_to_string fdef.A.name));
(* Sanity check - *)
- assert (List.length fdef.A.signature.region_params = 0);
- assert (List.length fdef.A.signature.type_params = 0);
+ assert (fdef.A.signature.generics = TypesUtils.mk_empty_generic_params);
assert (body.A.arg_count = 0);
(* Create the evaluation context *)
- let type_context, fun_context, global_context =
- compute_type_fun_global_contexts crate
- in
- let ctx =
- initialize_eval_context type_context fun_context global_context [] [] []
- in
+ let ctx = initialize_eval_context decls_ctx [] [] [] in
(* Insert the (uninitialized) local variables *)
let ctx = C.ctx_push_uninitialized_vars ctx body.A.locals in
@@ -536,9 +626,7 @@ module Test = struct
(no parameters, no arguments) - TODO: move *)
let fun_decl_is_transparent_unit (def : A.fun_decl) : bool =
Option.is_some def.body
- && def.A.signature.region_params = []
- && def.A.signature.type_params = []
- && def.A.signature.const_generic_params = []
+ && def.A.signature.generics = TypesUtils.mk_empty_generic_params
&& def.A.signature.inputs = []
(** Test all the unit functions in a list of function definitions *)
@@ -548,24 +636,9 @@ module Test = struct
(fun _ -> fun_decl_is_transparent_unit)
crate.functions
in
+ let decls_ctx = compute_contexts crate in
let test_unit_fun _ (def : A.fun_decl) : unit =
- test_unit_function crate def.A.def_id
+ test_unit_function crate decls_ctx def.A.def_id
in
A.FunDeclId.Map.iter test_unit_fun unit_funs
-
- (** Execute the symbolic interpreter on a function. *)
- let test_function_symbolic (synthesize : bool) (type_context : C.type_context)
- (fun_context : C.fun_context) (global_context : C.global_context)
- (fdef : A.fun_decl) : unit =
- (* Debug *)
- log#ldebug
- (lazy ("test_function_symbolic: " ^ Print.fun_name_to_string fdef.A.name));
-
- (* Evaluate *)
- let _ =
- evaluate_function_symbolic synthesize type_context fun_context
- global_context fdef
- in
-
- ()
end
diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml
index 4d67a4e4..e97795a1 100644
--- a/compiler/InterpreterBorrows.ml
+++ b/compiler/InterpreterBorrows.ml
@@ -452,7 +452,8 @@ let give_back_symbolic_value (_config : C.config)
| V.SynthInputGivenBack | SynthRetGivenBack | FunCallGivenBack | LoopGivenBack
->
()
- | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin | Aggregate ->
+ | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin | Aggregate
+ | ConstGeneric | TraitConst ->
raise (Failure "Unreachable"));
(* Store the given-back value as a meta-value for synthesis purposes *)
let mv = nsv in
diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml
index bf083aa4..e7da045c 100644
--- a/compiler/InterpreterBorrowsCore.ml
+++ b/compiler/InterpreterBorrowsCore.ml
@@ -100,15 +100,18 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
(compare_regions : T.RegionId.id T.region -> T.RegionId.id T.region -> bool)
(ty1 : T.rty) (ty2 : T.rty) : bool =
let compare = compare_rtys default combine compare_regions in
+ (* Normalize the associated types *)
match (ty1, ty2) with
| T.Literal lit1, T.Literal lit2 ->
assert (lit1 = lit2);
default
- | T.Adt (id1, regions1, tys1, cgs1), T.Adt (id2, regions2, tys2, cgs2) ->
+ | T.Adt (id1, generics1), T.Adt (id2, generics2) ->
assert (id1 = id2);
(* There are no regions in the const generics, so we ignore them,
but we still check they are the same, for sanity *)
- assert (cgs1 = cgs2);
+ assert (generics1.const_generics = generics2.const_generics);
+
+ (* We also ignore the trait refs *)
(* The check for the ADTs is very crude: we simply compare the arguments
* two by two.
@@ -123,14 +126,14 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
* this check would still be a reasonable conservative approximation. *)
(* Check the region parameters *)
- let regions = List.combine regions1 regions2 in
+ let regions = List.combine generics1.regions generics2.regions in
let params_b =
List.fold_left
(fun b (r1, r2) -> combine b (compare_regions r1 r2))
default regions
in
(* Check the type parameters *)
- let tys = List.combine tys1 tys2 in
+ let tys = List.combine generics1.types generics2.types in
let tys_b =
List.fold_left
(fun b (ty1, ty2) -> combine b (compare ty1 ty2))
@@ -150,6 +153,11 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
| T.TypeVar id1, T.TypeVar id2 ->
assert (id1 = id2);
default
+ | T.TraitType _, T.TraitType _ ->
+ (* The types should have been normalized. If after normalization we
+ get trait types, we can consider them as variables *)
+ assert (ty1 = ty2);
+ default
| _ ->
log#lerror
(lazy
diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml
index 81e73e3e..b267bb51 100644
--- a/compiler/InterpreterExpansion.ml
+++ b/compiler/InterpreterExpansion.ml
@@ -9,6 +9,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open TypesUtils
module Inv = Invariants
@@ -204,7 +205,7 @@ let apply_symbolic_expansion_non_borrow (config : C.config)
apply_symbolic_expansion_to_avalues config allow_reborrows original_sv
expansion ctx
-(** Compute the expansion of a non-assumed (i.e.: not [Option], [Box], etc.)
+(** Compute the expansion of a non-assumed (i.e.: not [Box], etc.)
adt value.
The function might return a list of values if the symbolic value to expand
@@ -214,18 +215,15 @@ let apply_symbolic_expansion_non_borrow (config : C.config)
doesn't allow the expansion of enumerations *containing several variants*.
*)
let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool)
- (kind : V.sv_kind) (def_id : T.TypeDeclId.id)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list
- =
+ (kind : V.sv_kind) (def_id : T.TypeDeclId.id) (generics : T.rgeneric_args)
+ (ctx : C.eval_ctx) : V.symbolic_expansion list =
(* Lookup the definition and check if it is an enumeration with several
* variants *)
let def = C.ctx_lookup_type_decl ctx def_id in
- assert (List.length regions = List.length def.T.region_params);
+ assert (List.length generics.regions = List.length def.T.generics.regions);
(* Retrieve, for every variant, the list of its instantiated field types *)
let variants_fields_types =
- Subst.type_decl_get_instantiated_variants_fields_rtypes def regions types
- cgs
+ Assoc.type_decl_get_inst_norm_variants_fields_rtypes ctx def generics
in
(* Check if there is strictly more than one variant *)
if List.length variants_fields_types > 1 && not expand_enumerations then
@@ -243,17 +241,6 @@ let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool)
(* Initialize all the expanded values of all the variants *)
List.map initialize variants_fields_types
-(** Compute the expansion of an Option value.
- *)
-let compute_expanded_symbolic_option_value (expand_enumerations : bool)
- (kind : V.sv_kind) (ty : T.rty) : V.symbolic_expansion list =
- assert expand_enumerations;
- let some_se =
- V.SeAdt (Some T.option_some_id, [ mk_fresh_symbolic_value kind ty ])
- in
- let none_se = V.SeAdt (Some T.option_none_id, []) in
- [ none_se; some_se ]
-
let compute_expanded_symbolic_tuple_value (kind : V.sv_kind)
(field_types : T.rty list) : V.symbolic_expansion =
(* Generate the field values *)
@@ -280,17 +267,14 @@ let compute_expanded_symbolic_box_value (kind : V.sv_kind) (boxed_ty : T.rty) :
doesn't allow the expansion of enumerations *containing several variants*.
*)
let compute_expanded_symbolic_adt_value (expand_enumerations : bool)
- (kind : V.sv_kind) (adt_id : T.type_id)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list
- =
- match (adt_id, regions, types) with
+ (kind : V.sv_kind) (adt_id : T.type_id) (generics : T.rgeneric_args)
+ (ctx : C.eval_ctx) : V.symbolic_expansion list =
+ match (adt_id, generics.regions, generics.types) with
| T.AdtId def_id, _, _ ->
compute_expanded_symbolic_non_assumed_adt_value expand_enumerations kind
- def_id regions types cgs ctx
- | T.Tuple, [], _ -> [ compute_expanded_symbolic_tuple_value kind types ]
- | T.Assumed T.Option, [], [ ty ] ->
- compute_expanded_symbolic_option_value expand_enumerations kind ty
+ def_id generics ctx
+ | T.Tuple, [], _ ->
+ [ compute_expanded_symbolic_tuple_value kind generics.types ]
| T.Assumed T.Box, [], [ boxed_ty ] ->
[ compute_expanded_symbolic_box_value kind boxed_ty ]
| _ ->
@@ -543,12 +527,12 @@ let expand_symbolic_value_no_branching (config : C.config)
fun cf ctx ->
match rty with
(* ADTs *)
- | T.Adt (adt_id, regions, types, cgs) ->
+ | T.Adt (adt_id, generics) ->
(* Compute the expanded value *)
let allow_branching = false in
let seel =
compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id
- regions types cgs ctx
+ generics ctx
in
(* There should be exacly one branch *)
let see = Collections.List.to_cons_nil seel in
@@ -600,12 +584,12 @@ let expand_symbolic_adt (config : C.config) (sv : V.symbolic_value)
(* Execute *)
match rty with
(* ADTs *)
- | T.Adt (adt_id, regions, types, cgs) ->
+ | T.Adt (adt_id, generics) ->
let allow_branching = true in
(* Compute the expanded value *)
let seel =
compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id
- regions types cgs ctx
+ generics ctx
in
(* Apply *)
let seel = List.map (fun see -> (Some see, cf_branches)) seel in
@@ -679,7 +663,7 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun =
^ symbolic_value_to_string ctx sv));
let cc : cm_fun =
match sv.V.sv_ty with
- | T.Adt (AdtId def_id, _, _, _) ->
+ | T.Adt (AdtId def_id, _) ->
(* {!expand_symbolic_value_no_branching} checks if there are branchings,
* but we prefer to also check it here - this leads to cleaner messages
* and debugging *)
@@ -704,16 +688,17 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun =
[config]): "
^ Print.name_to_string def.name))
else expand_symbolic_value_no_branching config sv None
- | T.Adt ((Tuple | Assumed Box), _, _, _) | T.Ref (_, _, _) ->
+ | T.Adt ((Tuple | Assumed Box), _) | T.Ref (_, _, _) ->
(* Ok *)
expand_symbolic_value_no_branching config sv None
- | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _, _, _)
- ->
+ | T.Adt (Assumed (Array | Slice | Str), _) ->
(* We can't expand those *)
raise
(Failure
"Attempted to greedily expand an ADT which can't be expanded ")
- | T.TypeVar _ | T.Literal _ | Never -> raise (Failure "Unreachable")
+ | T.TypeVar _ | T.Literal _ | Never | T.TraitType _ | T.Arrow _
+ | T.RawPtr _ ->
+ raise (Failure "Unreachable")
in
(* Compose and continue *)
comp cc expand cf ctx
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index 8b2070c6..245f3b77 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -7,6 +7,7 @@ module E = Expressions
open Utils
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open TypesUtils
open ValuesUtils
@@ -141,11 +142,19 @@ let rec copy_value (allow_adt_copy : bool) (config : C.config)
| V.Adt av ->
(* Sanity check *)
(match v.V.ty with
- | T.Adt (T.Assumed (T.Box | Vec), _, _, _) ->
+ | T.Adt (T.Assumed T.Box, _) ->
raise (Failure "Can't copy an assumed value other than Option")
- | T.Adt (T.AdtId _, _, _, _) -> assert allow_adt_copy
- | T.Adt ((T.Assumed Option | T.Tuple), _, _, _) -> () (* Ok *)
- | T.Adt (T.Assumed (Slice | T.Array), [], [ ty ], []) ->
+ | T.Adt (T.AdtId _, _) as ty ->
+ assert (allow_adt_copy || ty_is_primitively_copyable ty)
+ | T.Adt (T.Tuple, _) -> () (* Ok *)
+ | T.Adt
+ ( T.Assumed (Slice | T.Array),
+ {
+ regions = [];
+ types = [ ty ];
+ const_generics = [];
+ trait_refs = [];
+ } ) ->
assert (ty_is_primitively_copyable ty)
| _ -> raise (Failure "Unreachable"));
let ctx, fields =
@@ -230,17 +239,16 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) :
let prepare : cm_fun =
fun cf ctx ->
match op with
- | Expressions.Constant (ty, cv) ->
+ | E.Constant _ ->
(* No need to reorganize the context *)
- literal_to_typed_value (TypesUtils.ty_as_literal ty) cv |> ignore;
cf ctx
- | Expressions.Copy p ->
+ | E.Copy p ->
(* Access the value *)
let access = Read in
(* Expand the symbolic values, if necessary *)
let expand_prim_copy = true in
access_rplace_reorganize config expand_prim_copy access p cf ctx
- | Expressions.Move p ->
+ | E.Move p ->
(* Access the value *)
let access = Move in
let expand_prim_copy = false in
@@ -260,9 +268,71 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand)
^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n"));
(* Evaluate *)
match op with
- | Expressions.Constant (ty, cv) ->
- cf (literal_to_typed_value (TypesUtils.ty_as_literal ty) cv) ctx
- | Expressions.Copy p ->
+ | E.Constant cv -> (
+ match cv.value with
+ | E.CLiteral lit ->
+ cf (literal_to_typed_value (TypesUtils.ty_as_literal cv.ty) lit) ctx
+ | E.CTraitConst (trait_ref, generics, const_name) -> (
+ assert (generics = TypesUtils.mk_empty_generic_args);
+ match trait_ref.trait_id with
+ | T.TraitImpl _ ->
+ (* This shouldn't happen: if we refer to a concrete implementation, we
+ should directly refer to the top-level constant *)
+ raise (Failure "Unreachable")
+ | _ -> (
+ (* We refer to a constant defined in a local clause: simply
+ introduce a fresh symbolic value *)
+ let ctx0 = ctx in
+ (* Lookup the trait declaration to retrieve the type of the symbolic value *)
+ let trait_decl =
+ C.ctx_lookup_trait_decl ctx
+ trait_ref.trait_decl_ref.trait_decl_id
+ in
+ let _, (ty, _) =
+ List.find (fun (name, _) -> name = const_name) trait_decl.consts
+ in
+ (* Introduce a fresh symbolic value *)
+ let v = mk_fresh_symbolic_typed_value_from_ety V.TraitConst ty in
+ (* Continue the evaluation *)
+ let e = cf v ctx in
+ (* We have to wrap the generated expression *)
+ match e with
+ | None -> None
+ | Some e ->
+ Some
+ (SymbolicAst.IntroSymbolic
+ ( ctx0,
+ None,
+ value_as_symbolic v.value,
+ SymbolicAst.TraitConstValue
+ (trait_ref, generics, const_name),
+ e ))))
+ | E.CVar vid -> (
+ let ctx0 = ctx in
+ (* Lookup the const generic value *)
+ let cv = C.ctx_lookup_const_generic_value ctx vid in
+ (* Copy the value *)
+ let allow_adt_copy = false in
+ let ctx, v = copy_value allow_adt_copy config ctx cv in
+ (* Continue *)
+ let e = cf v ctx in
+ (* We have to wrap the generated expression *)
+ match e with
+ | None -> None
+ | Some e ->
+ (* If we are synthesizing a symbolic AST, it means that we are in symbolic
+ mode: the value of the const generic is necessarily symbolic. *)
+ assert (is_symbolic cv.V.value);
+ (* *)
+ Some
+ (SymbolicAst.IntroSymbolic
+ ( ctx0,
+ None,
+ value_as_symbolic v.value,
+ SymbolicAst.ConstGenericValue vid,
+ e )))
+ | E.CFnPtr _ -> raise (Failure "TODO"))
+ | E.Copy p ->
(* Access the value *)
let access = Read in
let cc = read_place access p in
@@ -283,7 +353,7 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand)
in
(* Compose and apply *)
comp cc copy cf ctx
- | Expressions.Move p ->
+ | E.Move p ->
(* Access the value *)
let access = Move in
let cc = read_place access p in
@@ -358,7 +428,7 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand)
match mk_scalar sv.int_ty i with
| Error _ -> cf (Error EPanic)
| Ok sv -> cf (Ok { v with V.value = V.Literal (PV.Scalar sv) }))
- | E.Cast (src_ty, tgt_ty), V.Literal (PV.Scalar sv) -> (
+ | E.Cast (E.CastInteger (src_ty, tgt_ty)), V.Literal (PV.Scalar sv) -> (
assert (src_ty = sv.int_ty);
let i = sv.PV.value in
match mk_scalar tgt_ty i with
@@ -384,7 +454,7 @@ let eval_unary_op_symbolic (config : C.config) (unop : E.unop) (op : E.operand)
match (unop, v.V.ty) with
| E.Not, (T.Literal Bool as lty) -> lty
| E.Neg, (T.Literal (Integer _) as lty) -> lty
- | E.Cast (_, tgt_ty), _ -> T.Literal (Integer tgt_ty)
+ | E.Cast (E.CastInteger (_, tgt_ty)), _ -> T.Literal (Integer tgt_ty)
| _ -> raise (Failure "Invalid input for unop")
in
let res_sv =
@@ -653,73 +723,46 @@ let eval_rvalue_aggregate (config : C.config)
fun ctx ->
(* Match on the aggregate kind *)
match aggregate_kind with
- | E.AggregatedTuple ->
- let tys = List.map (fun (v : V.typed_value) -> v.V.ty) values in
- let v = V.Adt { variant_id = None; field_values = values } in
- let ty = T.Adt (T.Tuple, [], tys, []) in
- let aggregated : V.typed_value = { V.value = v; ty } in
- (* Call the continuation *)
- cf aggregated ctx
- | E.AggregatedOption (variant_id, ty) ->
- (* Sanity check *)
- if variant_id = T.option_none_id then assert (values = [])
- else if variant_id = T.option_some_id then
- assert (List.length values = 1)
- else raise (Failure "Unreachable");
- (* Construt the value *)
- let aty = T.Adt (T.Assumed T.Option, [], [ ty ], []) in
- let av : V.adt_value =
- { V.variant_id = Some variant_id; V.field_values = values }
- in
- let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
- (* Call the continuation *)
- cf aggregated ctx
- | E.AggregatedAdt (def_id, opt_variant_id, regions, types, cgs) ->
- (* Sanity checks *)
- let type_decl = C.ctx_lookup_type_decl ctx def_id in
- assert (List.length type_decl.region_params = List.length regions);
- let expected_field_types =
- Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id
- types cgs
- in
- assert (
- expected_field_types
- = List.map (fun (v : V.typed_value) -> v.V.ty) values);
- (* Construct the value *)
- let av : V.adt_value =
- { V.variant_id = opt_variant_id; V.field_values = values }
- in
- let aty = T.Adt (T.AdtId def_id, regions, types, cgs) in
- let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
- (* Call the continuation *)
- cf aggregated ctx
- | E.AggregatedRange ety ->
- (* There should be two fields exactly *)
- let v0, v1 =
- match values with
- | [ v0; v1 ] -> (v0, v1)
- | _ -> raise (Failure "Unreachable")
- in
- (* Ranges are parametric over the type of indices. For now we only
- support scalars, which can be of any type *)
- assert (literal_type_is_integer (ty_as_literal ety));
- assert (v0.ty = ety);
- assert (v1.ty = ety);
- (* Construct the value *)
- let av : V.adt_value =
- { V.variant_id = None; V.field_values = values }
- in
- let aty = T.Adt (T.Assumed T.Range, [], [ ety ], []) in
- let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
- (* Call the continuation *)
- cf aggregated ctx
+ | E.AggregatedAdt (type_id, opt_variant_id, generics) -> (
+ match type_id with
+ | Tuple ->
+ let tys = List.map (fun (v : V.typed_value) -> v.V.ty) values in
+ let v = V.Adt { variant_id = None; field_values = values } in
+ let generics = TypesUtils.mk_generic_args [] tys [] [] in
+ let ty = T.Adt (T.Tuple, generics) in
+ let aggregated : V.typed_value = { V.value = v; ty } in
+ (* Call the continuation *)
+ cf aggregated ctx
+ | AdtId def_id ->
+ (* Sanity checks *)
+ let type_decl = C.ctx_lookup_type_decl ctx def_id in
+ assert (
+ List.length type_decl.generics.regions
+ = List.length generics.regions);
+ let expected_field_types =
+ Assoc.ctx_adt_get_inst_norm_field_etypes ctx def_id opt_variant_id
+ generics
+ in
+ assert (
+ expected_field_types
+ = List.map (fun (v : V.typed_value) -> v.V.ty) values);
+ (* Construct the value *)
+ let av : V.adt_value =
+ { V.variant_id = opt_variant_id; V.field_values = values }
+ in
+ let aty = T.Adt (T.AdtId def_id, generics) in
+ let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
+ (* Call the continuation *)
+ cf aggregated ctx
+ | Assumed _ -> raise (Failure "Unreachable"))
| E.AggregatedArray (ety, cg) -> (
(* Sanity check: all the values have the proper type *)
assert (List.for_all (fun (v : V.typed_value) -> v.V.ty = ety) values);
(* Sanity check: the number of values is consistent with the length *)
let len = (literal_as_scalar (const_generic_as_literal cg)).value in
assert (len = Z.of_int (List.length values));
- let ty = T.Adt (T.Assumed T.Array, [], [ ety ], [ cg ]) in
+ let generics = TypesUtils.mk_generic_args [] [ ety ] [ cg ] [] in
+ let ty = T.Adt (T.Assumed T.Array, generics) in
(* In order to generate a better AST, we introduce a symbolic
value equal to the array. The reason is that otherwise, the
array we introduce here might be duplicated in the generated
@@ -752,7 +795,7 @@ let eval_rvalue_not_global (config : C.config) (rvalue : E.rvalue)
(* Delegate to the proper auxiliary function *)
match rvalue with
| E.Use op -> comp_wrap (eval_operand config op) ctx
- | E.Ref (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx
+ | E.RvRef (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx
| E.UnaryOp (unop, op) -> eval_unary_op config unop op cf ctx
| E.BinaryOp (binop, op1, op2) -> eval_binary_op config binop op1 op2 cf ctx
| E.Aggregate (aggregate_kind, ops) ->
diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml
index bf88e055..6d3ecb18 100644
--- a/compiler/InterpreterLoopsJoinCtxs.ml
+++ b/compiler/InterpreterLoopsJoinCtxs.ml
@@ -554,9 +554,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
C.type_context;
fun_context;
global_context;
+ trait_decls_context;
+ trait_impls_context;
region_groups;
type_vars;
const_generic_vars;
+ const_generic_vars_map;
+ norm_trait_etypes;
+ norm_trait_rtypes;
+ norm_trait_stypes;
env = _;
ended_regions = ended_regions0;
} =
@@ -566,9 +572,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
C.type_context = _;
fun_context = _;
global_context = _;
+ trait_decls_context = _;
+ trait_impls_context = _;
region_groups = _;
type_vars = _;
const_generic_vars = _;
+ const_generic_vars_map = _;
+ norm_trait_etypes = _;
+ norm_trait_rtypes = _;
+ norm_trait_stypes = _;
env = _;
ended_regions = ended_regions1;
} =
@@ -580,9 +592,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
C.type_context;
fun_context;
global_context;
+ trait_decls_context;
+ trait_impls_context;
region_groups;
type_vars;
const_generic_vars;
+ const_generic_vars_map;
+ norm_trait_etypes;
+ norm_trait_rtypes;
+ norm_trait_stypes;
env;
ended_regions;
}
diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml
index 9248e513..8cab546e 100644
--- a/compiler/InterpreterLoopsMatchCtxs.ml
+++ b/compiler/InterpreterLoopsMatchCtxs.ml
@@ -149,20 +149,25 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty)
(match_regions : 'r -> 'r -> 'r) (ty0 : 'r T.ty) (ty1 : 'r T.ty) : 'r T.ty =
let match_rec = match_types match_distinct_types match_regions in
match (ty0, ty1) with
- | Adt (id0, regions0, tys0, cgs0), Adt (id1, regions1, tys1, cgs1) ->
+ | Adt (id0, generics0), Adt (id1, generics1) ->
assert (id0 = id1);
- assert (cgs0 = cgs1);
+ assert (generics0.const_generics = generics1.const_generics);
+ assert (generics0.trait_refs = generics1.trait_refs);
let id = id0 in
- let cgs = cgs1 in
+ let const_generics = generics1.const_generics in
+ let trait_refs = generics1.trait_refs in
let regions =
List.map
(fun (id0, id1) -> match_regions id0 id1)
- (List.combine regions0 regions1)
+ (List.combine generics0.regions generics1.regions)
in
- let tys =
- List.map (fun (ty0, ty1) -> match_rec ty0 ty1) (List.combine tys0 tys1)
+ let types =
+ List.map
+ (fun (ty0, ty1) -> match_rec ty0 ty1)
+ (List.combine generics0.types generics1.types)
in
- Adt (id, regions, tys, cgs)
+ let generics = { T.regions; types; const_generics; trait_refs } in
+ Adt (id, generics)
| TypeVar vid0, TypeVar vid1 ->
assert (vid0 = vid1);
let vid = vid0 in
diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml
index 04dc8892..2a277c91 100644
--- a/compiler/InterpreterPaths.ml
+++ b/compiler/InterpreterPaths.ml
@@ -3,6 +3,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open Cps
open ValuesUtils
@@ -95,16 +96,14 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
| pe :: p' -> (
(* Match on the projection element and the value *)
match (pe, v.V.value, v.V.ty) with
- | ( Field (((ProjAdt (_, _) | ProjOption _) as proj_kind), field_id),
+ | ( Field ((ProjAdt (_, _) as proj_kind), field_id),
V.Adt adt,
- T.Adt (type_id, _, _, _) ) -> (
+ T.Adt (type_id, _) ) -> (
(* Check consistency *)
(match (proj_kind, type_id) with
| ProjAdt (def_id, opt_variant_id), T.AdtId def_id' ->
assert (def_id = def_id');
assert (opt_variant_id = adt.variant_id)
- | ProjOption variant_id, T.Assumed T.Option ->
- assert (Some variant_id = adt.variant_id)
| _ -> raise (Failure "Unreachable"));
(* Actually project *)
let fv = T.FieldId.nth adt.field_values field_id in
@@ -119,8 +118,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
let updated = { v with value = nadt } in
Ok (ctx, { res with updated }))
(* Tuples *)
- | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _, _)
- -> (
+ | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _) -> (
assert (arity = List.length adt.field_values);
let fv = T.FieldId.nth adt.field_values field_id in
(* Project *)
@@ -136,7 +134,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
Ok (ctx, { res with updated })
(* If we reach Bottom, it may mean we need to expand an uninitialized
* enumeration value *))
- | Field ((ProjAdt (_, _) | ProjTuple _ | ProjOption _), _), V.Bottom, _ ->
+ | Field ((ProjAdt (_, _) | ProjTuple _), _), V.Bottom, _ ->
Error (FailBottom (1 + List.length p', pe, v.ty))
(* Symbolic value: needs to be expanded *)
| _, Symbolic sp, _ ->
@@ -145,9 +143,9 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
(* Box dereferencement *)
| ( DerefBox,
Adt { variant_id = None; field_values = [ bv ] },
- T.Adt (T.Assumed T.Box, _, _, _) ) -> (
- (* We allow moving inside of boxes. In practice, this kind of
- * manipulations should happen only inside unsage code, so
+ T.Adt (T.Assumed T.Box, _) ) -> (
+ (* We allow moving outside of boxes. In practice, this kind of
+ * manipulations should happen only inside unsafe code, so
* it shouldn't happen due to user code, and we leverage it
* when implementing box dereferencement for the concrete
* interpreter *)
@@ -357,45 +355,32 @@ let write_place (access : access_kind) (p : E.place) (nv : V.typed_value)
| Error e -> raise (Failure ("Unreachable: " ^ show_path_fail_kind e))
| Ok ctx -> ctx
-let compute_expanded_bottom_adt_value (tyctx : T.type_decl T.TypeDeclId.Map.t)
+let compute_expanded_bottom_adt_value (ctx : C.eval_ctx)
(def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
- (regions : T.erased_region list) (types : T.ety list)
- (cgs : T.const_generic list) : V.typed_value =
+ (generics : T.egeneric_args) : V.typed_value =
(* Lookup the definition and check if it is an enumeration - it
should be an enumeration if and only if the projection element
is a field projection with *some* variant id. Retrieve the list
of fields at the same time. *)
- let def = T.TypeDeclId.Map.find def_id tyctx in
- assert (List.length regions = List.length def.T.region_params);
+ let def = C.ctx_lookup_type_decl ctx def_id in
+ assert (List.length generics.regions = List.length def.T.generics.regions);
(* Compute the field types *)
let field_types =
- Subst.type_decl_get_instantiated_field_etypes def opt_variant_id types cgs
+ Assoc.type_decl_get_inst_norm_field_etypes ctx def opt_variant_id generics
in
(* Initialize the expanded value *)
let fields = List.map mk_bottom field_types in
let av = V.Adt { variant_id = opt_variant_id; field_values = fields } in
- let ty = T.Adt (T.AdtId def_id, regions, types, cgs) in
+ let ty = T.Adt (T.AdtId def_id, generics) in
{ V.value = av; V.ty }
-let compute_expanded_bottom_option_value (variant_id : T.VariantId.id)
- (param_ty : T.ety) : V.typed_value =
- (* Note that the variant can be [Some] or [None]: we expand bottom values
- * when writing to fields or setting discriminants *)
- let field_values =
- if variant_id = T.option_some_id then [ mk_bottom param_ty ]
- else if variant_id = T.option_none_id then []
- else raise (Failure "Unreachable")
- in
- let av = V.Adt { variant_id = Some variant_id; field_values } in
- let ty = T.Adt (T.Assumed T.Option, [], [ param_ty ], []) in
- { V.value = av; ty }
-
let compute_expanded_bottom_tuple_value (field_types : T.ety list) :
V.typed_value =
(* Generate the field values *)
let fields = List.map mk_bottom field_types in
let v = V.Adt { variant_id = None; field_values = fields } in
- let ty = T.Adt (T.Tuple, [], field_types, []) in
+ let generics = TypesUtils.mk_generic_args [] field_types [] [] in
+ let ty = T.Adt (T.Tuple, generics) in
{ V.value = v; V.ty }
(** Auxiliary helper to expand {!V.Bottom} values.
@@ -447,19 +432,18 @@ let expand_bottom_value_from_projection (access : access_kind) (p : E.place)
match (pe, ty) with
(* "Regular" ADTs *)
| ( Field (ProjAdt (def_id, opt_variant_id), _),
- T.Adt (T.AdtId def_id', regions, types, cgs) ) ->
+ T.Adt (T.AdtId def_id', generics) ) ->
assert (def_id = def_id');
- compute_expanded_bottom_adt_value ctx.type_context.type_decls def_id
- opt_variant_id regions types cgs
- (* Option *)
- | ( Field (ProjOption variant_id, _),
- T.Adt (T.Assumed T.Option, [], [ ty ], []) ) ->
- compute_expanded_bottom_option_value variant_id ty
+ compute_expanded_bottom_adt_value ctx def_id opt_variant_id generics
(* Tuples *)
- | Field (ProjTuple arity, _), T.Adt (T.Tuple, [], tys, []) ->
- assert (arity = List.length tys);
+ | ( Field (ProjTuple arity, _),
+ T.Adt
+ ( T.Tuple,
+ { T.regions = []; types; const_generics = []; trait_refs = [] } ) )
+ ->
+ assert (arity = List.length types);
(* Generate the field values *)
- compute_expanded_bottom_tuple_value tys
+ compute_expanded_bottom_tuple_value types
| _ ->
raise
(Failure
diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli
index 4a9f3b41..0ff8063f 100644
--- a/compiler/InterpreterPaths.mli
+++ b/compiler/InterpreterPaths.mli
@@ -3,6 +3,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open Cps
open InterpreterExpansion
@@ -56,18 +57,12 @@ val compute_expanded_bottom_tuple_value : T.ety list -> V.typed_value
(** Compute an expanded ADT ⊥ value *)
val compute_expanded_bottom_adt_value :
- T.type_decl T.TypeDeclId.Map.t ->
+ C.eval_ctx ->
T.TypeDeclId.id ->
T.VariantId.id option ->
- T.erased_region list ->
- T.ety list ->
- T.const_generic list ->
+ T.egeneric_args ->
V.typed_value
-(** Compute an expanded [Option] ⊥ value *)
-val compute_expanded_bottom_option_value :
- T.VariantId.id -> T.ety -> V.typed_value
-
(** Drop (end) outer loans at a given place, which should be seen as an l-value
(we will write to it later, but need to drop the loans before writing).
diff --git a/compiler/InterpreterProjectors.ml b/compiler/InterpreterProjectors.ml
index faed066b..9e0c2b75 100644
--- a/compiler/InterpreterProjectors.ml
+++ b/compiler/InterpreterProjectors.ml
@@ -3,6 +3,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open TypesUtils
open InterpreterUtils
@@ -24,12 +25,12 @@ let rec apply_proj_borrows_on_shared_borrow (ctx : C.eval_ctx)
else
match (v.V.value, ty) with
| V.Literal _, T.Literal _ -> []
- | V.Adt adt, T.Adt (id, region_params, tys, cgs) ->
+ | V.Adt adt, T.Adt (id, generics) ->
(* Retrieve the types of the fields *)
let field_types =
- Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id
- region_params tys cgs
+ Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics
in
+
(* Project over the field values *)
let fields_types = List.combine adt.V.field_values field_types in
let proj_fields =
@@ -103,11 +104,10 @@ let rec apply_proj_borrows (check_symbolic_no_ended : bool) (ctx : C.eval_ctx)
let value : V.avalue =
match (v.V.value, ty) with
| V.Literal _, T.Literal _ -> V.AIgnored
- | V.Adt adt, T.Adt (id, region_params, tys, cgs) ->
+ | V.Adt adt, T.Adt (id, generics) ->
(* Retrieve the types of the fields *)
let field_types =
- Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id
- region_params tys cgs
+ Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics
in
(* Project over the field values *)
let fields_types = List.combine adt.V.field_values field_types in
@@ -268,8 +268,7 @@ let apply_proj_loans_on_symbolic_expansion (regions : T.RegionId.Set.t)
let (value, ty) : V.avalue * T.rty =
match (see, original_sv_ty) with
| SeLiteral _, T.Literal _ -> (V.AIgnored, original_sv_ty)
- | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys, _cgs)
- ->
+ | SeAdt (variant_id, field_values), T.Adt (_id, _generics) ->
(* Project over the field values *)
let field_values =
List.map
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 045c4484..e0c4703b 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -10,13 +10,13 @@ open TypesUtils
open ValuesUtils
module Inv = Invariants
module S = SynthesizeSymbolic
-open Utils
open Cps
open InterpreterUtils
open InterpreterProjectors
open InterpreterExpansion
open InterpreterPaths
open InterpreterExpressions
+module PCtx = Print.EvalCtxLlbcAst
(** The local logger *)
let log = L.statements_log
@@ -232,9 +232,7 @@ let set_discriminant (config : C.config) (p : E.place)
let update_value cf (v : V.typed_value) : m_fun =
fun ctx ->
match (v.V.ty, v.V.value) with
- | ( T.Adt
- (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs),
- V.Adt av ) -> (
+ | T.Adt ((T.AdtId _ as type_id), generics), V.Adt av -> (
(* There are two situations:
- either the discriminant is already the proper one (in which case we
don't do anything)
@@ -251,28 +249,17 @@ let set_discriminant (config : C.config) (p : E.place)
let bottom_v =
match type_id with
| T.AdtId def_id ->
- compute_expanded_bottom_adt_value
- ctx.type_context.type_decls def_id (Some variant_id)
- regions types cgs
- | T.Assumed T.Option ->
- assert (regions = []);
- compute_expanded_bottom_option_value variant_id
- (Collections.List.to_cons_nil types)
+ compute_expanded_bottom_adt_value ctx def_id
+ (Some variant_id) generics
| _ -> raise (Failure "Unreachable")
in
assign_to_place config bottom_v p (cf Unit) ctx)
- | ( T.Adt
- (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs),
- V.Bottom ) ->
+ | T.Adt ((T.AdtId _ as type_id), generics), V.Bottom ->
let bottom_v =
match type_id with
| T.AdtId def_id ->
- compute_expanded_bottom_adt_value ctx.type_context.type_decls
- def_id (Some variant_id) regions types cgs
- | T.Assumed T.Option ->
- assert (regions = []);
- compute_expanded_bottom_option_value variant_id
- (Collections.List.to_cons_nil types)
+ compute_expanded_bottom_adt_value ctx def_id (Some variant_id)
+ generics
| _ -> raise (Failure "Unreachable")
in
assign_to_place config bottom_v p (cf Unit) ctx
@@ -301,24 +288,34 @@ let ctx_push_frame (ctx : C.eval_ctx) : C.eval_ctx =
let push_frame : cm_fun = fun cf ctx -> cf (ctx_push_frame ctx)
(** Small helper: compute the type of the return value for a specific
- instantiation of a non-local function.
+ instantiation of an assumed function.
*)
-let get_non_local_function_return_type (fid : A.assumed_fun_id)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (const_generic_params : T.const_generic list) : T.ety =
+let get_assumed_function_return_type (ctx : C.eval_ctx) (fid : A.assumed_fun_id)
+ (generics : T.egeneric_args) : T.ety =
+ assert (generics.trait_refs = []);
(* [Box::free] has a special treatment *)
- match (fid, region_params, type_params, const_generic_params) with
- | A.BoxFree, [], [ _ ], [] -> mk_unit_ty
+ match fid with
+ | BoxFree ->
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ mk_unit_ty
| _ ->
(* Retrieve the function's signature *)
- let sg = Assumed.get_assumed_sig fid in
+ let sg = Assumed.get_assumed_fun_sig fid in
(* Instantiate the return type *)
- let tsubst = Subst.make_type_subst_from_vars sg.type_params type_params in
- let cgsubst =
- Subst.make_const_generic_subst_from_vars sg.const_generic_params
- const_generic_params
+ (* There shouldn't be any reference to Self *)
+ let tr_self : T.erased_region T.trait_instance_id =
+ T.UnknownTrait __FUNCTION__
+ in
+ let { Subst.r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } =
+ Subst.make_esubst_from_generics sg.generics generics tr_self
+ in
+ let ty =
+ Subst.erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self
+ sg.output
in
- Subst.erase_regions_substitute_types tsubst cgsubst sg.output
+ Assoc.ctx_normalize_ety ctx ty
let move_return_value (config : C.config) (pop_return_value : bool)
(cf : V.typed_value option -> m_fun) : m_fun =
@@ -418,19 +415,14 @@ let pop_frame_assign (config : C.config) (dest : E.place) : cm_fun =
in
comp cf_pop cf_assign
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_replace_concrete (_config : C.config)
- (_region_params : T.erased_region list) (_type_params : T.ety list)
- (_cg_params : T.const_generic list) : cm_fun =
- fun _cf _ctx -> raise Unimplemented
-
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_box_new_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_params : T.const_generic list) : cm_fun =
+(** Auxiliary function - see {!eval_assumed_function_call} *)
+let eval_box_new_concrete (config : C.config) (generics : T.egeneric_args) :
+ cm_fun =
fun cf ctx ->
(* Check and retrieve the arguments *)
- match (region_params, type_params, cg_params, ctx.env) with
+ match
+ (generics.regions, generics.types, generics.const_generics, ctx.env)
+ with
| ( [],
[ boxed_ty ],
[],
@@ -448,7 +440,8 @@ let eval_box_new_concrete (config : C.config)
(* Create the new box *)
let cf_create cf (moved_input_value : V.typed_value) : m_fun =
(* Create the box value *)
- let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ], []) in
+ let generics = TypesUtils.mk_generic_args_from_types [ boxed_ty ] in
+ let box_ty = T.Adt (T.Assumed T.Box, generics) in
let box_v =
V.Adt { variant_id = None; field_values = [ moved_input_value ] }
in
@@ -466,71 +459,7 @@ let eval_box_new_concrete (config : C.config)
comp cf_move cf_create cf ctx
| _ -> raise (Failure "Inconsistent state")
-(** Auxiliary function which factorizes code to evaluate [std::Deref::deref]
- and [std::DerefMut::deref_mut] - see {!eval_non_local_function_call} *)
-let eval_box_deref_mut_or_shared_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_params : T.const_generic list) (is_mut : bool) : cm_fun =
- fun cf ctx ->
- (* Check the arguments *)
- match (region_params, type_params, cg_params, ctx.env) with
- | ( [],
- [ boxed_ty ],
- [],
- Var (VarBinder input_var, input_value)
- :: Var (_ret_var, _)
- :: C.Frame :: _ ) ->
- (* Required type checking. We must have:
- - input_value.ty = & (mut) Box<ty>
- - boxed_ty = ty
- for some ty
- *)
- (let _, input_ty, ref_kind = ty_get_ref input_value.V.ty in
- assert (match ref_kind with T.Shared -> not is_mut | T.Mut -> is_mut);
- let input_ty = ty_get_box input_ty in
- assert (input_ty = boxed_ty));
-
- (* Borrow the boxed value *)
- let p =
- { E.var_id = input_var.C.index; projection = [ E.Deref; E.DerefBox ] }
- in
- let borrow_kind = if is_mut then E.Mut else E.Shared in
- let rv = E.Ref (p, borrow_kind) in
- let cf_borrow = eval_rvalue_not_global config rv in
-
- (* Move the borrow to its destination *)
- let cf_move cf res : m_fun =
- match res with
- | Error EPanic ->
- (* We can't get there by borrowing a value *)
- raise (Failure "Unreachable")
- | Ok borrowed_value ->
- (* Move and continue *)
- let destp = mk_place_from_var_id E.VarId.zero in
- assign_to_place config borrowed_value destp cf
- in
-
- (* Compose and apply *)
- comp cf_borrow cf_move cf ctx
- | _ -> raise (Failure "Inconsistent state")
-
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_box_deref_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_params : T.const_generic list) : cm_fun =
- let is_mut = false in
- eval_box_deref_mut_or_shared_concrete config region_params type_params
- cg_params is_mut
-
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_box_deref_mut_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_params : T.const_generic list) : cm_fun =
- let is_mut = true in
- eval_box_deref_mut_or_shared_concrete config region_params type_params
- cg_params is_mut
-
-(** Auxiliary function - see {!eval_non_local_function_call}.
+(** Auxiliary function - see {!eval_assumed_function_call}.
[Box::free] is not handled the same way as the other assumed functions:
- in the regular case, whenever we need to evaluate an assumed function,
@@ -549,11 +478,10 @@ let eval_box_deref_mut_concrete (config : C.config)
It thus updates the box value (by calling {!drop_value}) and updates
the destination (by setting it to [()]).
*)
-let eval_box_free (config : C.config) (region_params : T.erased_region list)
- (type_params : T.ety list) (cg_params : T.const_generic list)
+let eval_box_free (config : C.config) (generics : T.egeneric_args)
(args : E.operand list) (dest : E.place) : cm_fun =
fun cf ctx ->
- match (region_params, type_params, cg_params, args) with
+ match (generics.regions, generics.types, generics.const_generics, args) with
| [], [ boxed_ty ], [], [ E.Move input_box_place ] ->
(* Required type checking *)
let input_box = InterpreterPaths.read_place Write input_box_place ctx in
@@ -570,26 +498,24 @@ let eval_box_free (config : C.config) (region_params : T.erased_region list)
cc cf ctx
| _ -> raise (Failure "Inconsistent state")
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_vec_function_concrete (_config : C.config) (_fid : A.assumed_fun_id)
- (_region_params : T.erased_region list) (_type_params : T.ety list)
- (_cg_params : T.const_generic list) : cm_fun =
- fun _cf _ctx -> raise Unimplemented
-
(** Evaluate a non-local function call in concrete mode *)
-let eval_non_local_function_call_concrete (config : C.config)
- (fid : A.assumed_fun_id) (region_params : T.erased_region list)
- (type_params : T.ety list) (cg_params : T.const_generic list)
- (args : E.operand list) (dest : E.place) : cm_fun =
+let eval_assumed_function_call_concrete (config : C.config)
+ (fid : A.assumed_fun_id) (call : A.call) : cm_fun =
+ let generics = call.func.generics in
+ let args = call.args in
+ let dest = call.dest in
+ (* Sanity check: we don't fully handle the const generic vars environment
+ in concrete mode yet *)
+ assert (generics.const_generics = []);
(* There are two cases (and this is extremely annoying):
- the function is not box_free
- the function is box_free
See {!eval_box_free}
*)
match fid with
- | A.BoxFree ->
+ | BoxFree ->
(* Degenerate case: box_free *)
- eval_box_free config region_params type_params cg_params args dest
+ eval_box_free config generics args dest
| _ ->
(* "Normal" case: not box_free *)
(* Evaluate the operands *)
@@ -604,16 +530,14 @@ let eval_non_local_function_call_concrete (config : C.config)
* but it made it less clear where the computed values came from,
* so we reversed the modifications. *)
let cf_eval_call cf (args_vl : V.typed_value list) : m_fun =
+ fun ctx ->
(* Push the stack frame: we initialize the frame with the return variable,
and one variable per input argument *)
let cc = push_frame in
(* Create and push the return variable *)
let ret_vid = E.VarId.zero in
- let ret_ty =
- get_non_local_function_return_type fid region_params type_params
- cg_params
- in
+ let ret_ty = get_assumed_function_return_type ctx fid generics in
let ret_var = mk_var ret_vid (Some "@return") ret_ty in
let cc = comp cc (push_uninitialized_var ret_var) in
@@ -630,24 +554,12 @@ let eval_non_local_function_call_concrete (config : C.config)
* access to a body. *)
let cf_eval_body : cm_fun =
match fid with
- | A.Replace ->
- eval_replace_concrete config region_params type_params cg_params
- | BoxNew ->
- eval_box_new_concrete config region_params type_params cg_params
- | BoxDeref ->
- eval_box_deref_concrete config region_params type_params cg_params
- | BoxDerefMut ->
- eval_box_deref_mut_concrete config region_params type_params
- cg_params
+ | BoxNew -> eval_box_new_concrete config generics
| BoxFree ->
(* Should have been treated above *) raise (Failure "Unreachable")
- | VecNew | VecPush | VecInsert | VecLen | VecIndex | VecIndexMut ->
- eval_vec_function_concrete config fid region_params type_params
- cg_params
| ArrayIndexShared | ArrayIndexMut | ArrayToSliceShared
- | ArrayToSliceMut | ArraySubsliceShared | ArraySubsliceMut
- | SliceIndexShared | SliceIndexMut | SliceSubsliceShared
- | SliceSubsliceMut | SliceLen ->
+ | ArrayToSliceMut | ArrayRepeat | SliceIndexShared | SliceIndexMut
+ | SliceLen ->
raise (Failure "Unimplemented")
in
@@ -657,50 +569,11 @@ let eval_non_local_function_call_concrete (config : C.config)
let cc = comp cc (pop_frame_assign config dest) in
(* Continue *)
- cc cf
+ cc cf ctx
in
(* Compose and apply *)
comp cf_eval_ops cf_eval_call
-let instantiate_fun_sig (type_params : T.ety list)
- (cg_params : T.const_generic list) (sg : A.fun_sig) : A.inst_fun_sig =
- (* Generate fresh abstraction ids and create a substitution from region
- * group ids to abstraction ids *)
- let rg_abs_ids_bindings =
- List.map
- (fun rg ->
- let abs_id = C.fresh_abstraction_id () in
- (rg.T.id, abs_id))
- sg.regions_hierarchy
- in
- let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t =
- List.fold_left
- (fun mp (rg_id, abs_id) -> T.RegionGroupId.Map.add rg_id abs_id mp)
- T.RegionGroupId.Map.empty rg_abs_ids_bindings
- in
- let asubst (rg_id : T.RegionGroupId.id) : V.AbstractionId.id =
- T.RegionGroupId.Map.find rg_id asubst_map
- in
- (* Generate fresh regions and their substitutions *)
- let _, rsubst, _ = Subst.fresh_regions_with_substs sg.region_params in
- (* Generate the type substitution
- * Note that we need the substitution to map the type variables to
- * {!rty} types (not {!ety}). In order to do that, we convert the
- * type parameters to types with regions. This is possible only
- * if those types don't contain any regions.
- * This is a current limitation of the analysis: there is still some
- * work to do to properly handle full type parametrization.
- * *)
- let rtype_params = List.map ety_no_regions_to_rty type_params in
- let tsubst = Subst.make_type_subst_from_vars sg.type_params rtype_params in
- let cgsubst =
- Subst.make_const_generic_subst_from_vars sg.const_generic_params cg_params
- in
- (* Substitute the signature *)
- let inst_sig = Subst.substitute_signature asubst rsubst tsubst cgsubst sg in
- (* Return *)
- inst_sig
-
(** Helper
Create abstractions (with no avalues, which have to be inserted afterwards)
@@ -836,7 +709,7 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun =
match rvalue with
| E.Global _ -> raise (Failure "Unreachable")
| E.Use _
- | E.Ref (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow))
+ | E.RvRef (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow))
| E.UnaryOp _ | E.BinaryOp _ | E.Discriminant _
| E.Aggregate _ ->
let rp = rvalue_get_place rvalue in
@@ -893,7 +766,15 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id)
match config.mode with
| ConcreteMode ->
(* Treat the evaluation of the global as a call to the global body (without arguments) *)
- (eval_local_function_call_concrete config global.body_id [] [] [] [] dest)
+ let func =
+ {
+ E.func = FunId (Regular global.body_id);
+ generics = TypesUtils.mk_empty_generic_args;
+ trait_and_method_generic_args = None;
+ }
+ in
+ let call = { A.func; args = []; dest } in
+ (eval_transparent_function_call_concrete config global.body_id call)
cf ctx
| SymbolicMode ->
(* Generate a fresh symbolic value. In the translation, this fresh symbolic value will be
@@ -1037,128 +918,374 @@ and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun =
(** Evaluate a function call (auxiliary helper for [eval_statement]) *)
and eval_function_call (config : C.config) (call : A.call) : st_cm_fun =
- (* There are two cases:
+ (* There are several cases:
- this is a local function, in which case we execute its body
- - this is a non-local function, in which case there is a special treatment
+ - this is an assumed function, in which case there is a special treatment
+ - this is a trait method
*)
- match call.func with
- | A.Regular fid ->
- eval_local_function_call config fid call.region_args call.type_args
- call.const_generic_args call.args call.dest
- | A.Assumed fid ->
- eval_non_local_function_call config fid call.region_args call.type_args
- call.const_generic_args call.args call.dest
+ match config.mode with
+ | C.ConcreteMode -> eval_function_call_concrete config call
+ | C.SymbolicMode -> eval_function_call_symbolic config call
-(** Evaluate a local (i.e., non-assumed) function call in concrete mode *)
-and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id)
- (_region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
+and eval_function_call_concrete (config : C.config) (call : A.call) : st_cm_fun
+ =
fun cf ctx ->
- (* Retrieve the (correctly instantiated) body *)
- let def = C.ctx_lookup_fun_decl ctx fid in
- (* We can evaluate the function call only if it is not opaque *)
- let body =
- match def.body with
- | None ->
- raise
- (Failure
- ("Can't evaluate a call to an opaque function: "
- ^ Print.name_to_string def.name))
- | Some body -> body
- in
- let tsubst =
- Subst.make_type_subst_from_vars def.A.signature.type_params type_args
- in
- let cgsubst =
- Subst.make_const_generic_subst_from_vars
- def.A.signature.const_generic_params cg_args
- in
- let locals, body_st = Subst.fun_body_substitute_in_body tsubst cgsubst body in
-
- (* Evaluate the input operands *)
- assert (List.length args = body.A.arg_count);
- let cc = eval_operands config args in
-
- (* Push a frame delimiter - we use {!comp_transmit} to transmit the result
- * of the operands evaluation from above to the functions afterwards, while
- * ignoring it in this function *)
- let cc = comp_transmit cc push_frame in
-
- (* Compute the initial values for the local variables *)
- (* 1. Push the return value *)
- let ret_var, locals =
- match locals with
- | ret_ty :: locals -> (ret_ty, locals)
- | _ -> raise (Failure "Unreachable")
- in
- let input_locals, locals =
- Collections.List.split_at locals body.A.arg_count
- in
+ match call.func.func with
+ | FunId (Regular fid) ->
+ eval_transparent_function_call_concrete config fid call cf ctx
+ | FunId (Assumed fid) ->
+ (* Continue - note that we do as if the function call has been successful,
+ * by giving {!Unit} to the continuation, because we place us in the case
+ * where we haven't panicked. Of course, the translation needs to take the
+ * panic case into account... *)
+ eval_assumed_function_call_concrete config fid call (cf Unit) ctx
+ | TraitMethod _ -> raise (Failure "Unimplemented")
+
+and eval_function_call_symbolic (config : C.config) (call : A.call) : st_cm_fun
+ =
+ match call.func.func with
+ | FunId (Regular _) | TraitMethod _ ->
+ eval_transparent_function_call_symbolic config call
+ | FunId (Assumed fid) -> eval_assumed_function_call_symbolic config fid call
- let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in
-
- (* 2. Push the input values *)
- let cf_push_inputs cf args =
- let inputs = List.combine input_locals args in
- (* Note that this function checks that the variables and their values
- * have the same type (this is important) *)
- push_vars inputs cf
- in
- let cc = comp cc cf_push_inputs in
-
- (* 3. Push the remaining local variables (initialized as {!Bottom}) *)
- let cc = comp cc (push_uninitialized_vars locals) in
+(** Evaluate a local (i.e., non-assumed) function call in concrete mode *)
+and eval_transparent_function_call_concrete (config : C.config)
+ (fid : A.FunDeclId.id) (call : A.call) : st_cm_fun =
+ let generics = call.func.generics in
+ let args = call.A.args in
+ let dest = call.A.dest in
+ (* Sanity check: we don't fully handle the const generic vars environment
+ in concrete mode yet *)
+ assert (generics.const_generics = []);
+ fun cf ctx ->
+ (* Retrieve the (correctly instantiated) body *)
+ let def = C.ctx_lookup_fun_decl ctx fid in
+ (* We can evaluate the function call only if it is not opaque *)
+ let body =
+ match def.body with
+ | None ->
+ raise
+ (Failure
+ ("Can't evaluate a call to an opaque function: "
+ ^ Print.name_to_string def.name))
+ | Some body -> body
+ in
+ (* TODO: we need to normalize the types if we want to correctly support traits *)
+ assert (generics.trait_refs = []);
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst =
+ Subst.make_esubst_from_generics def.A.signature.generics generics tr_self
+ in
+ let locals, body_st = Subst.fun_body_substitute_in_body subst body in
+
+ (* Evaluate the input operands *)
+ assert (List.length args = body.A.arg_count);
+ let cc = eval_operands config args in
+
+ (* Push a frame delimiter - we use {!comp_transmit} to transmit the result
+ * of the operands evaluation from above to the functions afterwards, while
+ * ignoring it in this function *)
+ let cc = comp_transmit cc push_frame in
+
+ (* Compute the initial values for the local variables *)
+ (* 1. Push the return value *)
+ let ret_var, locals =
+ match locals with
+ | ret_ty :: locals -> (ret_ty, locals)
+ | _ -> raise (Failure "Unreachable")
+ in
+ let input_locals, locals =
+ Collections.List.split_at locals body.A.arg_count
+ in
- (* Execute the function body *)
- let cc = comp cc (eval_function_body config body_st) in
+ let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in
- (* Pop the stack frame and move the return value to its destination *)
- let cf_finish cf res =
- match res with
- | Panic -> cf Panic
- | Return ->
- (* Pop the stack frame, retrieve the return value, move it to
- * its destination and continue *)
- pop_frame_assign config dest (cf Unit)
- | Break _ | Continue _ | Unit | LoopReturn _ | EndEnterLoop _
- | EndContinue _ ->
- raise (Failure "Unreachable")
- in
- let cc = comp cc cf_finish in
+ (* 2. Push the input values *)
+ let cf_push_inputs cf args =
+ let inputs = List.combine input_locals args in
+ (* Note that this function checks that the variables and their values
+ * have the same type (this is important) *)
+ push_vars inputs cf
+ in
+ let cc = comp cc cf_push_inputs in
+
+ (* 3. Push the remaining local variables (initialized as {!Bottom}) *)
+ let cc = comp cc (push_uninitialized_vars locals) in
+
+ (* Execute the function body *)
+ let cc = comp cc (eval_function_body config body_st) in
+
+ (* Pop the stack frame and move the return value to its destination *)
+ let cf_finish cf res =
+ match res with
+ | Panic -> cf Panic
+ | Return ->
+ (* Pop the stack frame, retrieve the return value, move it to
+ * its destination and continue *)
+ pop_frame_assign config dest (cf Unit)
+ | Break _ | Continue _ | Unit | LoopReturn _ | EndEnterLoop _
+ | EndContinue _ ->
+ raise (Failure "Unreachable")
+ in
+ let cc = comp cc cf_finish in
- (* Continue *)
- cc cf ctx
+ (* Continue *)
+ cc cf ctx
(** Evaluate a local (i.e., non-assumed) function call in symbolic mode *)
-and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id)
- (region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
+and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
+ : st_cm_fun =
fun cf ctx ->
- (* Retrieve the (correctly instantiated) signature *)
- let def = C.ctx_lookup_fun_decl ctx fid in
- let sg = def.A.signature in
- (* Instantiate the signature and introduce fresh abstraction and region ids
- * while doing so *)
- let inst_sg = instantiate_fun_sig type_args cg_args sg in
+ (* Instantiate the signature and introduce fresh abstractions and region ids while doing so.
+
+ We perform some manipulations when instantiating the signature.
+
+ # Trait impl calls
+ ==================
+ In particular, we have a special treatment of trait method calls when
+ the trait ref is a known impl.
+
+ For instance:
+ {[
+ trait HasValue {
+ fn has_value(&self) -> bool;
+ }
+
+ impl<T> HasValue for Option<T> {
+ fn has_value(&self) {
+ match self {
+ None => false,
+ Some(_) => true,
+ }
+ }
+ }
+
+ fn option_has_value<T>(x: &Option<T>) -> bool {
+ x.has_value()
+ }
+ ]}
+
+ The generated code looks like this:
+ {[
+ structure HasValue (Self : Type) = {
+ has_value : Self -> result bool
+ }
+
+ let OptionHasValueImpl.has_value (Self : Type) (self : Self) : result bool =
+ match self with
+ | None => false
+ | Some _ => true
+
+ let OptionHasValueInstance (T : Type) : HasValue (Option T) = {
+ has_value = OptionHasValueInstance.has_value
+ }
+ ]}
+
+ In [option_has_value], we don't want to refer to the [has_value] method
+ of the instance of [HasValue] for [Option<T>]. We want to refer directly
+ to the function which implements [has_value] for [Option<T>].
+ That is, instead of generating this:
+ {[
+ let option_has_value (T : Type) (x : Option T) : result bool =
+ (OptionHasValueInstance T).has_value x
+ ]}
+
+ We want to generate this:
+ {[
+ let option_has_value (T : Type) (x : Option T) : result bool =
+ OptionHasValueImpl.has_value T x
+ ]}
+
+ # Provided trait methods
+ ========================
+ Calls to provided trait methods also have a special treatment because
+ for now we forbid overriding provided trait methods in the trait implementations,
+ which means that whenever we call a provided trait method, we do not refer
+ to a trait clause but directly to the method provided in the trait declaration.
+ *)
+ let func, generics, def, inst_sg =
+ match call.func.func with
+ | FunId (Regular fid) ->
+ let def = C.ctx_lookup_fun_decl ctx fid in
+ log#ldebug
+ (lazy
+ ("fun call:\n- call: " ^ call_to_string ctx call
+ ^ "\n- call.generics:\n"
+ ^ egeneric_args_to_string ctx call.func.generics
+ ^ "\n- def.signature:\n"
+ ^ fun_sig_to_string ctx def.A.signature));
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let inst_sg =
+ instantiate_fun_sig ctx call.func.generics tr_self def.A.signature
+ in
+ (call.func.func, call.func.generics, def, inst_sg)
+ | FunId (Assumed _) ->
+ (* Unreachable: must be a transparent function *)
+ raise (Failure "Unreachable")
+ | TraitMethod (trait_ref, method_name, _) -> (
+ log#ldebug
+ (lazy
+ ("trait method call:\n- call: " ^ call_to_string ctx call
+ ^ "\n- method name: " ^ method_name ^ "\n- call.generics:\n"
+ ^ egeneric_args_to_string ctx call.func.generics
+ ^ "\n- trait and method generics:\n"
+ ^ egeneric_args_to_string ctx
+ (Option.get call.func.trait_and_method_generic_args)));
+ (* When instantiating, we need to group the generics for the trait ref
+ and the method *)
+ let generics = Option.get call.func.trait_and_method_generic_args in
+ (* Lookup the trait method signature - there are several possibilities
+ depending on whethere we call a top-level trait method impl or the
+ method from a local clause *)
+ match trait_ref.trait_id with
+ | TraitImpl impl_id -> (
+ (* Lookup the trait impl *)
+ let trait_impl = C.ctx_lookup_trait_impl ctx impl_id in
+ log#ldebug
+ (lazy ("trait impl: " ^ trait_impl_to_string ctx trait_impl));
+ (* First look in the required methods *)
+ let method_id =
+ List.find_opt
+ (fun (s, _) -> s = method_name)
+ trait_impl.required_methods
+ in
+ match method_id with
+ | Some (_, id) ->
+ (* This is a required method *)
+ let method_def = C.ctx_lookup_fun_decl ctx id in
+ (* Instantiate *)
+ let tr_self =
+ T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref)
+ in
+ let inst_sg =
+ instantiate_fun_sig ctx generics tr_self
+ method_def.A.signature
+ in
+ (* Also update the function identifier: we want to forget
+ the fact that we called a trait method, and treat it as
+ a regular function call to the top-level function
+ which implements the method. In order to do this properly,
+ we also need to update the generics.
+ *)
+ let func = E.FunId (Regular id) in
+ (func, generics, method_def, inst_sg)
+ | None ->
+ (* If not found, lookup the methods provided by the trait *declaration*
+ (remember: for now, we forbid overriding provided methods) *)
+ assert (trait_impl.provided_methods = []);
+ let trait_decl =
+ C.ctx_lookup_trait_decl ctx
+ trait_ref.trait_decl_ref.trait_decl_id
+ in
+ let _, method_id =
+ List.find
+ (fun (s, _) -> s = method_name)
+ trait_decl.provided_methods
+ in
+ let method_id = Option.get method_id in
+ let method_def = C.ctx_lookup_fun_decl ctx method_id in
+ (* For the instantiation we have to do something peculiar
+ because the method was defined for the trait declaration.
+ We have to group:
+ - the parameters given to the trait decl reference
+ - the parameters given to the method itself
+ For instance:
+ {[
+ trait Foo<T> {
+ fn f<U>(...) { ... }
+ }
+
+ fn g<G>(x : G) where Clause0: Foo<G, bool>
+ {
+ x.f::<u32>(...) // The arguments to f are: <G, bool, u32>
+ }
+ ]}
+ *)
+ let all_generics =
+ TypesUtils.merge_generic_args
+ trait_ref.trait_decl_ref.decl_generics call.func.generics
+ in
+ log#ldebug
+ (lazy
+ ("provided method call:" ^ "\n- method name: " ^ method_name
+ ^ "\n- all_generics:\n"
+ ^ egeneric_args_to_string ctx all_generics
+ ^ "\n- parent params info: "
+ ^ Print.option_to_string A.show_params_info
+ method_def.signature.parent_params_info));
+ let tr_self =
+ T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref)
+ in
+ let inst_sg =
+ instantiate_fun_sig ctx all_generics tr_self
+ method_def.A.signature
+ in
+ (call.func.func, call.func.generics, method_def, inst_sg))
+ | _ ->
+ (* We are using a local clause - we lookup the trait decl *)
+ let trait_decl =
+ C.ctx_lookup_trait_decl ctx trait_ref.trait_decl_ref.trait_decl_id
+ in
+ (* Lookup the method decl in the required *and* the provided methods *)
+ let _, method_id =
+ let provided =
+ List.filter_map
+ (fun (id, f) ->
+ match f with None -> None | Some f -> Some (id, f))
+ trait_decl.provided_methods
+ in
+ List.find
+ (fun (s, _) -> s = method_name)
+ (List.append trait_decl.required_methods provided)
+ in
+ let method_def = C.ctx_lookup_fun_decl ctx method_id in
+ log#ldebug (lazy ("method:\n" ^ fun_decl_to_string ctx method_def));
+ (* Instantiate *)
+ let tr_self = T.TraitRef trait_ref in
+ let tr_self =
+ TypesUtils.etrait_instance_id_no_regions_to_gr_trait_instance_id
+ tr_self
+ in
+ let inst_sg =
+ instantiate_fun_sig ctx generics tr_self method_def.A.signature
+ in
+ (call.func.func, call.func.generics, method_def, inst_sg))
+ in
(* Sanity check *)
- assert (List.length args = List.length def.A.signature.inputs);
+ assert (List.length call.args = List.length def.A.signature.inputs);
(* Evaluate the function call *)
- eval_function_call_symbolic_from_inst_sig config (A.Regular fid) inst_sg
- region_args type_args cg_args args dest cf ctx
+ eval_function_call_symbolic_from_inst_sig config func inst_sg generics
+ call.args call.dest cf ctx
(** Evaluate a function call in symbolic mode by using the function signature.
This allows us to factorize the evaluation of local and non-local function
calls in symbolic mode: only their signatures matter.
+
+ The [self_trait_ref] trait ref refers to [Self]. We use it when calling
+ a provided trait method, because those methods have a special treatment:
+ we dot not group them with the required trait methods, and forbid (for now)
+ overriding them. We treat them as regular method, which take an additional
+ trait ref as input.
*)
and eval_function_call_symbolic_from_inst_sig (config : C.config)
- (fid : A.fun_id) (inst_sg : A.inst_fun_sig)
- (_region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
+ (fid : A.fun_id_or_trait_method_ref) (inst_sg : A.inst_fun_sig)
+ (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) :
st_cm_fun =
fun cf ctx ->
+ log#ldebug
+ (lazy
+ ("eval_function_call_symbolic_from_inst_sig:\n- fid: "
+ ^ fun_id_or_trait_method_ref_to_string ctx fid
+ ^ "\n- inst_sg:\n"
+ ^ inst_fun_sig_to_string ctx inst_sg
+ ^ "\n- call.generics:\n"
+ ^ egeneric_args_to_string ctx generics
+ ^ "\n- args:\n"
+ ^ String.concat ", " (List.map (operand_to_string ctx) args)
+ ^ "\n- dest:\n" ^ place_to_string ctx dest));
+
(* Generate a fresh symbolic value for the return value *)
let ret_sv_ty = inst_sg.A.output in
let ret_spc = mk_fresh_symbolic_value V.FunCallRet ret_sv_ty in
@@ -1224,8 +1351,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config)
let expr = cf ctx in
(* Synthesize the symbolic AST *)
- S.synthesize_regular_function_call fid call_id ctx abs_ids type_args cg_args
- args args_places ret_spc dest_place expr
+ S.synthesize_regular_function_call fid call_id ctx abs_ids generics args
+ args_places ret_spc dest_place expr
in
let cc = comp cc cf_call in
@@ -1286,17 +1413,18 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config)
cc (cf Unit) ctx
(** Evaluate a non-local function call in symbolic mode *)
-and eval_non_local_function_call_symbolic (config : C.config)
- (fid : A.assumed_fun_id) (region_args : T.erased_region list)
- (type_args : T.ety list) (cg_args : T.const_generic list)
- (args : E.operand list) (dest : E.place) : st_cm_fun =
+and eval_assumed_function_call_symbolic (config : C.config)
+ (fid : A.assumed_fun_id) (call : A.call) : st_cm_fun =
fun cf ctx ->
+ let generics = call.func.generics in
+ let args = call.args in
+ let dest = call.dest in
(* Sanity check: make sure the type parameters don't contain regions -
* this is a current limitation of our synthesis *)
assert (
List.for_all
(fun ty -> not (ty_has_borrows ctx.type_context.type_infos ty))
- type_args);
+ generics.types);
(* There are two cases (and this is extremely annoying):
- the function is not box_free
@@ -1304,10 +1432,10 @@ and eval_non_local_function_call_symbolic (config : C.config)
See {!eval_box_free}
*)
match fid with
- | A.BoxFree ->
+ | BoxFree ->
(* Degenerate case: box_free - note that this is not really a function
* call: no need to call a "synthesize_..." function *)
- eval_box_free config region_args type_args cg_args args dest (cf Unit) ctx
+ eval_box_free config generics args dest (cf Unit) ctx
| _ ->
(* "Normal" case: not box_free *)
(* In symbolic mode, the behaviour of a function call is completely defined
@@ -1315,59 +1443,19 @@ and eval_non_local_function_call_symbolic (config : C.config)
* instantiated signatures, and delegate the work to an auxiliary function *)
let inst_sig =
match fid with
- | A.BoxFree ->
+ | BoxFree ->
(* should have been treated above *)
raise (Failure "Unreachable")
| _ ->
- instantiate_fun_sig type_args cg_args (Assumed.get_assumed_sig fid)
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ instantiate_fun_sig ctx generics tr_self
+ (Assumed.get_assumed_fun_sig fid)
in
(* Evaluate the function call *)
- eval_function_call_symbolic_from_inst_sig config (A.Assumed fid) inst_sig
- region_args type_args cg_args args dest cf ctx
-
-(** Evaluate a non-local (i.e, assumed) function call such as [Box::deref]
- (auxiliary helper for [eval_statement]) *)
-and eval_non_local_function_call (config : C.config) (fid : A.assumed_fun_id)
- (region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
- fun cf ctx ->
- (* Debug *)
- log#ldebug
- (lazy
- (let type_args =
- "[" ^ String.concat ", " (List.map (ety_to_string ctx) type_args) ^ "]"
- in
- let args =
- "[" ^ String.concat ", " (List.map (operand_to_string ctx) args) ^ "]"
- in
- let dest = place_to_string ctx dest in
- "eval_non_local_function_call:\n- fid:" ^ A.show_assumed_fun_id fid
- ^ "\n- type_args: " ^ type_args ^ "\n- args: " ^ args ^ "\n- dest: "
- ^ dest));
-
- match config.mode with
- | C.ConcreteMode ->
- eval_non_local_function_call_concrete config fid region_args type_args
- cg_args args dest (cf Unit) ctx
- | C.SymbolicMode ->
- eval_non_local_function_call_symbolic config fid region_args type_args
- cg_args args dest cf ctx
-
-(** Evaluate a local (i.e, not assumed) function call (auxiliary helper for
- [eval_statement]) *)
-and eval_local_function_call (config : C.config) (fid : A.FunDeclId.id)
- (region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
- match config.mode with
- | ConcreteMode ->
- eval_local_function_call_concrete config fid region_args type_args cg_args
- args dest
- | SymbolicMode ->
- eval_local_function_call_symbolic config fid region_args type_args cg_args
- args dest
+ eval_function_call_symbolic_from_inst_sig config (FunId (Assumed fid))
+ inst_sig generics args dest cf ctx
(** Evaluate a statement seen as a function body *)
and eval_function_body (config : C.config) (body : A.statement) : st_cm_fun =
diff --git a/compiler/InterpreterStatements.mli b/compiler/InterpreterStatements.mli
index 814bc964..e65758ae 100644
--- a/compiler/InterpreterStatements.mli
+++ b/compiler/InterpreterStatements.mli
@@ -25,15 +25,6 @@ open InterpreterExpressions
*)
val pop_frame : C.config -> bool -> (V.typed_value option -> m_fun) -> m_fun
-(** Instantiate a function signature, introducing **fresh** abstraction ids and
- region ids. This is mostly used in preparation of function calls, when
- evaluating in symbolic mode of course.
-
- Note: there are no region parameters, because they should be erased.
- *)
-val instantiate_fun_sig :
- T.ety list -> T.const_generic list -> LA.fun_sig -> LA.inst_fun_sig
-
(** Helper.
Create a list of abstractions from a list of regions groups, and insert
diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml
index 7bd37550..6e08e553 100644
--- a/compiler/InterpreterUtils.ml
+++ b/compiler/InterpreterUtils.ml
@@ -10,6 +10,11 @@ open TypesUtils
module PA = Print.EvalCtxLlbcAst
open Cps
+(* TODO: we should probably rename the file to ContextsUtils *)
+
+(** The local logger *)
+let log = L.interpreter_log
+
(** Some utilities *)
(** Auxiliary function - call a function which requires a continuation,
@@ -38,6 +43,20 @@ let typed_value_to_string = PA.typed_value_to_string
let typed_avalue_to_string = PA.typed_avalue_to_string
let place_to_string = PA.place_to_string
let operand_to_string = PA.operand_to_string
+let egeneric_args_to_string = PA.egeneric_args_to_string
+let rtrait_instance_id_to_string = PA.rtrait_instance_id_to_string
+let fun_sig_to_string = PA.fun_sig_to_string
+let inst_fun_sig_to_string = PA.inst_fun_sig_to_string
+
+let fun_id_or_trait_method_ref_to_string =
+ PA.fun_id_or_trait_method_ref_to_string
+
+let fun_decl_to_string = PA.fun_decl_to_string
+let call_to_string = PA.call_to_string
+
+let trait_impl_to_string ctx =
+ PA.trait_impl_to_string { ctx with type_vars = []; const_generic_vars = [] }
+
let statement_to_string ctx = PA.statement_to_string ctx "" " "
let statement_to_string_with_tab ctx = PA.statement_to_string ctx " " " "
let env_elem_to_string ctx = PA.env_elem_to_string ctx "" " "
@@ -255,7 +274,8 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : C.eval_ctx)
raise Found
else ()
| V.SynthInput | V.SynthInputGivenBack | V.FunCallGivenBack
- | V.SynthRetGivenBack | V.Global | V.LoopGivenBack | V.Aggregate ->
+ | V.SynthRetGivenBack | V.Global | V.LoopGivenBack | V.Aggregate
+ | V.ConstGeneric | V.TraitConst ->
()
end
in
@@ -272,7 +292,7 @@ let rvalue_get_place (rv : E.rvalue) : E.place option =
match rv with
| Use (Copy p | Move p) -> Some p
| Use (Constant _) -> None
- | Ref (p, _) -> Some p
+ | RvRef (p, _) -> Some p
| UnaryOp _ | BinaryOp _ | Global _ | Discriminant _ | Aggregate _ -> None
(** See {!ValuesUtils.symbolic_value_has_borrows} *)
@@ -403,3 +423,103 @@ let compute_contexts_ids (ctxl : C.eval_ctx list) : ids_sets * ids_to_values =
(** Compute the sets of ids found in a context. *)
let compute_context_ids (ctx : C.eval_ctx) : ids_sets * ids_to_values =
compute_contexts_ids [ ctx ]
+
+(** **WARNING**: this function doesn't compute the normalized types
+ (for the trait type aliases). This should be computed afterwards.
+ *)
+let initialize_eval_context (ctx : C.decls_ctx)
+ (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list)
+ (const_generic_vars : T.const_generic_var list) : C.eval_ctx =
+ C.reset_global_counters ();
+ let const_generic_vars_map =
+ T.ConstGenericVarId.Map.of_list
+ (List.map
+ (fun (cg : T.const_generic_var) ->
+ let ty = TypesUtils.ety_no_regions_to_rty (T.Literal cg.ty) in
+ let cv = mk_fresh_symbolic_typed_value V.ConstGeneric ty in
+ (cg.index, cv))
+ const_generic_vars)
+ in
+ {
+ C.type_context = ctx.type_ctx;
+ C.fun_context = ctx.fun_ctx;
+ C.global_context = ctx.global_ctx;
+ C.trait_decls_context = ctx.trait_decls_ctx;
+ C.trait_impls_context = ctx.trait_impls_ctx;
+ C.region_groups;
+ C.type_vars;
+ C.const_generic_vars;
+ C.const_generic_vars_map;
+ C.norm_trait_etypes = C.ETraitTypeRefMap.empty (* Empty for now *);
+ C.norm_trait_rtypes = C.RTraitTypeRefMap.empty (* Empty for now *);
+ C.norm_trait_stypes = C.STraitTypeRefMap.empty (* Empty for now *);
+ C.env = [ C.Frame ];
+ C.ended_regions = T.RegionId.Set.empty;
+ }
+
+(** Instantiate a function signature, introducing **fresh** abstraction ids and
+ region ids. This is mostly used in preparation of function calls (when
+ evaluating in symbolic mode).
+
+ Note: there are no region parameters, because they should be erased.
+ *)
+let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.egeneric_args)
+ (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
+ log#ldebug
+ (lazy
+ ("instantiate_fun_sig:" ^ "\n- generics: "
+ ^ egeneric_args_to_string ctx generics
+ ^ "\n- tr_self: "
+ ^ rtrait_instance_id_to_string ctx tr_self
+ ^ "\n- sg: " ^ fun_sig_to_string ctx sg));
+ (* Generate fresh abstraction ids and create a substitution from region
+ * group ids to abstraction ids *)
+ let rg_abs_ids_bindings =
+ List.map
+ (fun rg ->
+ let abs_id = C.fresh_abstraction_id () in
+ (rg.T.id, abs_id))
+ sg.regions_hierarchy
+ in
+ let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t =
+ List.fold_left
+ (fun mp (rg_id, abs_id) -> T.RegionGroupId.Map.add rg_id abs_id mp)
+ T.RegionGroupId.Map.empty rg_abs_ids_bindings
+ in
+ let asubst (rg_id : T.RegionGroupId.id) : V.AbstractionId.id =
+ T.RegionGroupId.Map.find rg_id asubst_map
+ in
+ (* Generate fresh regions and their substitutions *)
+ let _, rsubst, _ = Subst.fresh_regions_with_substs sg.generics.regions in
+ (* Generate the type substitution
+ * Note that we need the substitution to map the type variables to
+ * {!rty} types (not {!ety}). In order to do that, we convert the
+ * type parameters to types with regions. This is possible only
+ * if those types don't contain any regions.
+ * This is a current limitation of the analysis: there is still some
+ * work to do to properly handle full type parametrization.
+ * *)
+ let rtype_params = List.map ety_no_regions_to_rty generics.types in
+ let tsubst = Subst.make_type_subst_from_vars sg.generics.types rtype_params in
+ let cgsubst =
+ Subst.make_const_generic_subst_from_vars sg.generics.const_generics
+ generics.const_generics
+ in
+ (* TODO: something annoying with the trait ref subst: we need to use region
+ types, but the arguments use erased regions. For now we use the fact
+ that no regions should appear inside. In the future: we should merge
+ ety and rty. *)
+ let trait_refs =
+ List.map TypesUtils.etrait_ref_no_regions_to_gr_trait_ref
+ generics.trait_refs
+ in
+ let tr_subst =
+ Subst.make_trait_subst_from_clauses sg.generics.trait_clauses trait_refs
+ in
+ (* Substitute the signature *)
+ let inst_sig =
+ AssociatedTypes.ctx_subst_norm_signature ctx asubst rsubst tsubst cgsubst
+ tr_subst tr_self sg
+ in
+ (* Return *)
+ inst_sig
diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml
index f29c7f88..5c8ec7af 100644
--- a/compiler/Invariants.ml
+++ b/compiler/Invariants.ml
@@ -7,6 +7,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module A = LlbcAst
module L = Logging
open Cps
@@ -406,13 +407,14 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(match (tv.V.value, tv.V.ty) with
| V.Literal cv, T.Literal ty -> check_literal_type cv ty
(* ADT case *)
- | V.Adt av, T.Adt (T.AdtId def_id, regions, tys, cgs) ->
+ | V.Adt av, T.Adt (T.AdtId def_id, generics) ->
(* Retrieve the definition to check the variant id, the number of
* parameters, etc. *)
let def = C.ctx_lookup_type_decl ctx def_id in
(* Check the number of parameters *)
- assert (List.length regions = List.length def.region_params);
- assert (List.length tys = List.length def.type_params);
+ assert (
+ List.length generics.regions = List.length def.generics.regions);
+ assert (List.length generics.types = List.length def.generics.types);
(* Check that the variant id is consistent *)
(match (av.V.variant_id, def.T.kind) with
| Some variant_id, T.Enum variants ->
@@ -421,8 +423,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
| _ -> raise (Failure "Erroneous typing"));
(* Check that the field types are correct *)
let field_types =
- Subst.type_decl_get_instantiated_field_etypes def av.V.variant_id
- tys cgs
+ Assoc.type_decl_get_inst_norm_field_etypes ctx def av.V.variant_id
+ generics
in
let fields_with_types =
List.combine av.V.field_values field_types
@@ -431,34 +433,31 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty))
fields_with_types
(* Tuple case *)
- | V.Adt av, T.Adt (T.Tuple, regions, tys, cgs) ->
- assert (regions = []);
- assert (cgs = []);
+ | V.Adt av, T.Adt (T.Tuple, generics) ->
+ assert (generics.regions = []);
+ assert (generics.const_generics = []);
assert (av.V.variant_id = None);
(* Check that the fields have the proper values - and check that there
* are as many fields as field types at the same time *)
- let fields_with_types = List.combine av.V.field_values tys in
+ let fields_with_types =
+ List.combine av.V.field_values generics.types
+ in
List.iter
(fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty))
fields_with_types
(* Assumed type case *)
- | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> (
- assert (av.V.variant_id = None || aty_id = T.Option);
- match (aty_id, av.V.field_values, regions, tys, cgs) with
+ | V.Adt av, T.Adt (T.Assumed aty_id, generics) -> (
+ assert (av.V.variant_id = None);
+ match
+ ( aty_id,
+ av.V.field_values,
+ generics.regions,
+ generics.types,
+ generics.const_generics )
+ with
(* Box *)
- | T.Box, [ inner_value ], [], [ inner_ty ], []
- | T.Option, [ inner_value ], [], [ inner_ty ], [] ->
+ | T.Box, [ inner_value ], [], [ inner_ty ], [] ->
assert (inner_value.V.ty = inner_ty)
- | T.Option, _, [], [ _ ], [] ->
- (* Option::None: nothing to check *)
- ()
- | T.Vec, fvs, [], [ vec_ty ], [] ->
- List.iter
- (fun (v : V.typed_value) -> assert (v.ty = vec_ty))
- fvs
- | T.Range, [ v0; v1 ], [], [ inner_ty ], [] ->
- assert (v0.V.ty = inner_ty);
- assert (v1.V.ty = inner_ty)
| T.Array, inner_values, _, [ inner_ty ], [ cg ] ->
(* *)
assert (
@@ -520,14 +519,17 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(* Check the current pair (value, type) *)
(match (atv.V.value, atv.V.ty) with
(* ADT case *)
- | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys, cgs) ->
+ | V.AAdt av, T.Adt (T.AdtId def_id, generics) ->
(* Retrieve the definition to check the variant id, the number of
* parameters, etc. *)
let def = C.ctx_lookup_type_decl ctx def_id in
(* Check the number of parameters *)
- assert (List.length regions = List.length def.region_params);
- assert (List.length tys = List.length def.type_params);
- assert (List.length cgs = List.length def.const_generic_params);
+ assert (
+ List.length generics.regions = List.length def.generics.regions);
+ assert (List.length generics.types = List.length def.generics.types);
+ assert (
+ List.length generics.const_generics
+ = List.length def.generics.const_generics);
(* Check that the variant id is consistent *)
(match (av.V.variant_id, def.T.kind) with
| Some variant_id, T.Enum variants ->
@@ -536,8 +538,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
| _ -> raise (Failure "Erroneous typing"));
(* Check that the field types are correct *)
let field_types =
- Subst.type_decl_get_instantiated_field_rtypes def av.V.variant_id
- regions tys cgs
+ Assoc.type_decl_get_inst_norm_field_rtypes ctx def av.V.variant_id
+ generics
in
let fields_with_types =
List.combine av.V.field_values field_types
@@ -546,20 +548,28 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty))
fields_with_types
(* Tuple case *)
- | V.AAdt av, T.Adt (T.Tuple, regions, tys, cgs) ->
- assert (regions = []);
- assert (cgs = []);
+ | V.AAdt av, T.Adt (T.Tuple, generics) ->
+ assert (generics.regions = []);
+ assert (generics.const_generics = []);
assert (av.V.variant_id = None);
(* Check that the fields have the proper values - and check that there
* are as many fields as field types at the same time *)
- let fields_with_types = List.combine av.V.field_values tys in
+ let fields_with_types =
+ List.combine av.V.field_values generics.types
+ in
List.iter
(fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty))
fields_with_types
(* Assumed type case *)
- | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> (
+ | V.AAdt av, T.Adt (T.Assumed aty_id, generics) -> (
assert (av.V.variant_id = None);
- match (aty_id, av.V.field_values, regions, tys, cgs) with
+ match
+ ( aty_id,
+ av.V.field_values,
+ generics.regions,
+ generics.types,
+ generics.const_generics )
+ with
(* Box *)
| T.Box, [ boxed_value ], [], [ boxed_ty ], [] ->
assert (boxed_value.V.ty = boxed_ty)
diff --git a/compiler/LlbcAst.ml b/compiler/LlbcAst.ml
index f4d26e18..2db859b2 100644
--- a/compiler/LlbcAst.ml
+++ b/compiler/LlbcAst.ml
@@ -11,6 +11,7 @@ type abs_region_groups = (AbstractionId.id, RegionId.id) g_region_groups
(** A function signature, after instantiation *)
type inst_fun_sig = {
regions_hierarchy : abs_region_groups;
+ trait_type_constraints : rtrait_type_constraint list;
inputs : rty list;
output : rty;
}
diff --git a/compiler/LlbcAstUtils.ml b/compiler/LlbcAstUtils.ml
index 1111c297..0ab4ed94 100644
--- a/compiler/LlbcAstUtils.ml
+++ b/compiler/LlbcAstUtils.ml
@@ -5,10 +5,46 @@ let lookup_fun_sig (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) :
fun_sig =
match fun_id with
| Regular id -> (FunDeclId.Map.find id fun_decls).signature
- | Assumed aid -> Assumed.get_assumed_sig aid
+ | Assumed aid -> Assumed.get_assumed_fun_sig aid
let lookup_fun_name (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) :
Names.fun_name =
match fun_id with
| Regular id -> (FunDeclId.Map.find id fun_decls).name
- | Assumed aid -> Assumed.get_assumed_name aid
+ | Assumed aid -> Assumed.get_assumed_fun_name aid
+
+(** Return the opaque declarations found in the crate, which are also *not builtin*.
+
+ [filter_assumed]: if [true], do not consider as opaque the external definitions
+ that we will map to definitions from the standard library.
+
+ Remark: the list of functions also contains the list of opaque global bodies.
+ *)
+let crate_get_opaque_non_builtin_decls (k : crate) (filter_assumed : bool) :
+ T.type_decl list * fun_decl list =
+ let open ExtractBuiltin in
+ let is_opaque_fun (d : fun_decl) : bool =
+ let sname = name_to_simple_name d.name in
+ d.body = None
+ (* Something to pay attention to: we must ignore trait method *declarations*
+ (which don't have a body but must not be considered as opaque) *)
+ && (match d.kind with TraitMethodDecl _ -> false | _ -> true)
+ && ((not filter_assumed)
+ || (not (SimpleNameMap.mem sname builtin_globals_map))
+ && not (SimpleNameMap.mem sname (builtin_funs_map ())))
+ in
+ let is_opaque_type (d : T.type_decl) : bool =
+ let sname = name_to_simple_name d.name in
+ d.kind = T.Opaque
+ && ((not filter_assumed)
+ || not (SimpleNameMap.mem sname (builtin_types_map ())))
+ in
+ (* Note that by checking the function bodies we also the globals *)
+ ( List.filter is_opaque_type (T.TypeDeclId.Map.values k.types),
+ List.filter is_opaque_fun (FunDeclId.Map.values k.functions) )
+
+(** Return true if the crate contains opaque declarations, ignoring the assumed
+ definitions. *)
+let crate_has_opaque_non_builtin_decls (k : crate) (filter_assumed : bool) :
+ bool =
+ crate_get_opaque_non_builtin_decls k filter_assumed <> ([], [])
diff --git a/compiler/Logging.ml b/compiler/Logging.ml
index 9dc1f5e3..721655b8 100644
--- a/compiler/Logging.ml
+++ b/compiler/Logging.ml
@@ -9,6 +9,9 @@ let pre_passes_log = L.get_logger "MainLogger.PrePasses"
(** Logger for Translate *)
let translate_log = L.get_logger "MainLogger.Translate"
+(** Logger for Contexts *)
+let contexts_log = L.get_logger "MainLogger.Contexts"
+
(** Logger for PureUtils *)
let pure_utils_log = L.get_logger "MainLogger.PureUtils"
@@ -19,7 +22,7 @@ let symbolic_to_pure_log = L.get_logger "MainLogger.SymbolicToPure"
let pure_micro_passes_log = L.get_logger "MainLogger.PureMicroPasses"
(** Logger for ExtractBase *)
-let pure_to_extract_log = L.get_logger "MainLogger.ExtractBase"
+let extract_log = L.get_logger "MainLogger.ExtractBase"
(** Logger for Interpreter *)
let interpreter_log = L.get_logger "MainLogger.Interpreter"
@@ -57,6 +60,9 @@ let borrows_log = L.get_logger "MainLogger.Interpreter.Borrows"
(** Logger for Invariants *)
let invariants_log = L.get_logger "MainLogger.Interpreter.Invariants"
+(** Logger for AssociatedTypes *)
+let associated_types_log = L.get_logger "MainLogger.AssociatedTypes"
+
(** Logger for SCC *)
let scc_log = L.get_logger "MainLogger.Graph.SCC"
diff --git a/compiler/PrePasses.ml b/compiler/PrePasses.ml
index b348ba1d..ee06fa07 100644
--- a/compiler/PrePasses.ml
+++ b/compiler/PrePasses.ml
@@ -107,8 +107,8 @@ let remove_useless_cf_merges (crate : A.crate) (f : A.fun_decl) : A.fun_decl =
false
| Assign (_, rv) -> (
match rv with
- | Use _ | Ref _ -> not must_end_with_exit
- | Aggregate (AggregatedTuple, []) -> not must_end_with_exit
+ | Use _ | RvRef _ -> not must_end_with_exit
+ | Aggregate (AggregatedAdt (Tuple, _, _), []) -> not must_end_with_exit
| _ -> false)
| FakeRead _ | Drop _ | Nop -> not must_end_with_exit
| Panic | Return -> true
@@ -376,7 +376,7 @@ let remove_shallow_borrows (crate : A.crate) (f : A.fun_decl) : A.fun_decl =
method! visit_Assign env p rv =
match (p.projection, rv) with
- | [], E.Ref (_, E.Shallow) ->
+ | [], E.RvRef (_, E.Shallow) ->
(* Filter *)
filtered := E.VarId.Set.add p.var_id !filtered;
Nop
diff --git a/compiler/Print.ml b/compiler/Print.ml
index 9aa73d7c..7f0d95ff 100644
--- a/compiler/Print.ml
+++ b/compiler/Print.ml
@@ -21,6 +21,9 @@ module Values = struct
type_decl_id_to_string : T.TypeDeclId.id -> string;
const_generic_var_id_to_string : T.ConstGenericVarId.id -> string;
global_decl_id_to_string : T.GlobalDeclId.id -> string;
+ trait_decl_id_to_string : T.TraitDeclId.id -> string;
+ trait_impl_id_to_string : T.TraitImplId.id -> string;
+ trait_clause_id_to_string : T.TraitClauseId.id -> string;
adt_variant_to_string : T.TypeDeclId.id -> T.VariantId.id -> string;
var_id_to_string : E.VarId.id -> string;
adt_field_names :
@@ -34,6 +37,9 @@ module Values = struct
PT.type_decl_id_to_string = fmt.type_decl_id_to_string;
PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
PT.global_decl_id_to_string = fmt.global_decl_id_to_string;
+ PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let value_to_rtype_formatter (fmt : value_formatter) : PT.rtype_formatter =
@@ -43,6 +49,9 @@ module Values = struct
PT.type_decl_id_to_string = fmt.type_decl_id_to_string;
PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
PT.global_decl_id_to_string = fmt.global_decl_id_to_string;
+ PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let value_to_stype_formatter (fmt : value_formatter) : PT.stype_formatter =
@@ -52,6 +61,9 @@ module Values = struct
PT.type_decl_id_to_string = fmt.type_decl_id_to_string;
PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
PT.global_decl_id_to_string = fmt.global_decl_id_to_string;
+ PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let var_id_to_string (id : E.VarId.id) : string =
@@ -86,10 +98,10 @@ module Values = struct
List.map (typed_value_to_string fmt) av.field_values
in
match v.ty with
- | T.Adt (T.Tuple, _, _, _) ->
+ | T.Adt (T.Tuple, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | T.Adt (T.AdtId def_id, _, _, _) ->
+ | T.Adt (T.AdtId def_id, _) ->
(* "Regular" ADT *)
let adt_ident =
match av.variant_id with
@@ -111,21 +123,10 @@ module Values = struct
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | T.Adt (T.Assumed aty, _, _, _) -> (
+ | T.Adt (T.Assumed aty, _) -> (
(* Assumed type *)
match (aty, field_values) with
| Box, [ bv ] -> "@Box(" ^ bv ^ ")"
- | Option, _ ->
- if av.variant_id = Some T.option_some_id then
- "@Option::Some("
- ^ Collections.List.to_cons_nil field_values
- ^ ")"
- else if av.variant_id = Some T.option_none_id then (
- assert (field_values = []);
- "@Option::None")
- else raise (Failure "Unreachable")
- | Range, _ -> "@Range{ " ^ String.concat ", " field_values ^ "}"
- | Vec, _ -> "@Vec[" ^ String.concat ", " field_values ^ "]"
| Array, _ ->
(* Happens when we aggregate values *)
"@Array[" ^ String.concat ", " field_values ^ "]"
@@ -201,10 +202,10 @@ module Values = struct
List.map (typed_avalue_to_string fmt) av.field_values
in
match v.ty with
- | T.Adt (T.Tuple, _, _, _) ->
+ | T.Adt (T.Tuple, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | T.Adt (T.AdtId def_id, _, _, _) ->
+ | T.Adt (T.AdtId def_id, _) ->
(* "Regular" ADT *)
let adt_ident =
match av.variant_id with
@@ -226,7 +227,7 @@ module Values = struct
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | T.Adt (T.Assumed aty, _, _, _) -> (
+ | T.Adt (T.Assumed aty, _) -> (
(* Assumed type *)
match (aty, field_values) with
| Box, [ bv ] -> "@Box(" ^ bv ^ ")"
@@ -347,6 +348,18 @@ module Values = struct
^ "}" ^ "{regions="
^ T.RegionId.Set.to_string None abs.regions
^ "}" ^ " {\n" ^ avs ^ "\n" ^ indent ^ "}"
+
+ let inst_fun_sig_to_string (fmt : value_formatter) (sg : LlbcAst.inst_fun_sig)
+ : string =
+ (* TODO: print the trait type constraints? *)
+ let ty_fmt = value_to_rtype_formatter fmt in
+ let ty_to_string = PT.ty_to_string ty_fmt in
+
+ let inputs =
+ "(" ^ String.concat ", " (List.map ty_to_string sg.inputs) ^ ")"
+ in
+ let output = ty_to_string sg.output in
+ inputs ^ " -> " ^ output
end
module PV = Values (* local module *)
@@ -452,6 +465,9 @@ module Contexts = struct
PV.adt_variant_to_string = fmt.adt_variant_to_string;
PV.var_id_to_string = fmt.var_id_to_string;
PV.adt_field_names = fmt.adt_field_names;
+ PV.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PV.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PV.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let ast_to_value_formatter (fmt : PA.ast_formatter) : PV.value_formatter =
@@ -463,20 +479,27 @@ module Contexts = struct
let ctx_to_rtype_formatter (fmt : ctx_formatter) : PT.rtype_formatter =
PV.value_to_rtype_formatter fmt
+ let ctx_to_stype_formatter (fmt : ctx_formatter) : PT.stype_formatter =
+ PV.value_to_stype_formatter fmt
+
let eval_ctx_to_ctx_formatter (ctx : C.eval_ctx) : ctx_formatter =
- (* We shouldn't use rvar_to_string *)
- let rvar_to_string _r =
- raise (Failure "Unexpected use of rvar_to_string")
+ let rvar_to_string r =
+ (* In theory we shouldn't use rvar_to_string, but it can happen
+ when printing definitions for instance... *)
+ T.RegionVarId.to_string r
in
let r_to_string r = PT.region_id_to_string r in
let type_var_id_to_string vid =
- let v = C.lookup_type_var ctx vid in
- v.name
+ (* The context may be invalid *)
+ match C.lookup_type_var_opt ctx vid with
+ | None -> T.TypeVarId.to_string vid
+ | Some v -> v.name
in
let const_generic_var_id_to_string vid =
- let v = C.lookup_const_generic_var ctx vid in
- v.name
+ match C.lookup_const_generic_var_opt ctx vid with
+ | None -> T.ConstGenericVarId.to_string vid
+ | Some v -> v.name
in
let type_decl_id_to_string def_id =
let def = C.ctx_lookup_type_decl ctx def_id in
@@ -486,6 +509,15 @@ module Contexts = struct
let def = C.ctx_lookup_global_decl ctx def_id in
name_to_string def.name
in
+ let trait_decl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_decl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_impl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_impl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in
let adt_variant_to_string =
PT.type_ctx_to_adt_variant_to_string_fun ctx.type_context.type_decls
in
@@ -506,6 +538,9 @@ module Contexts = struct
adt_variant_to_string;
var_id_to_string;
adt_field_names;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
let eval_ctx_to_ast_formatter (ctx : C.eval_ctx) : PA.ast_formatter =
@@ -521,6 +556,15 @@ module Contexts = struct
let def = C.ctx_lookup_global_decl ctx def_id in
global_name_to_string def.name
in
+ let trait_decl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_decl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_impl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_impl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in
{
rvar_to_string = ctx_fmt.PV.rvar_to_string;
r_to_string = ctx_fmt.PV.r_to_string;
@@ -533,6 +577,9 @@ module Contexts = struct
adt_field_to_string;
fun_decl_id_to_string;
global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
(** Split an [env] at every occurrence of [Frame], eliminating those elements.
@@ -608,6 +655,68 @@ module EvalCtxLlbcAst = struct
let fmt = PC.ctx_to_rtype_formatter fmt in
PT.rty_to_string fmt t
+ let sty_to_string (ctx : C.eval_ctx) (t : T.sty) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_stype_formatter fmt in
+ PT.sty_to_string fmt t
+
+ let generic_params_to_strings (ctx : C.eval_ctx) (x : T.generic_params) :
+ string list * string list =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_stype_formatter fmt in
+ PT.generic_params_to_strings fmt x
+
+ let egeneric_args_to_string (ctx : C.eval_ctx) (x : T.egeneric_args) : string
+ =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_etype_formatter fmt in
+ PT.egeneric_args_to_string fmt x
+
+ let rgeneric_args_to_string (ctx : C.eval_ctx) (x : T.rgeneric_args) : string
+ =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_rtype_formatter fmt in
+ PT.rgeneric_args_to_string fmt x
+
+ let sgeneric_args_to_string (ctx : C.eval_ctx) (x : T.sgeneric_args) : string
+ =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_stype_formatter fmt in
+ PT.sgeneric_args_to_string fmt x
+
+ let etrait_ref_to_string (ctx : C.eval_ctx) (x : T.etrait_ref) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_etype_formatter fmt in
+ PT.etrait_ref_to_string fmt x
+
+ let rtrait_ref_to_string (ctx : C.eval_ctx) (x : T.rtrait_ref) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_rtype_formatter fmt in
+ PT.rtrait_ref_to_string fmt x
+
+ let strait_ref_to_string (ctx : C.eval_ctx) (x : T.strait_ref) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_stype_formatter fmt in
+ PT.strait_ref_to_string fmt x
+
+ let etrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.etrait_instance_id)
+ : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_etype_formatter fmt in
+ PT.etrait_instance_id_to_string fmt x
+
+ let rtrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.rtrait_instance_id)
+ : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_rtype_formatter fmt in
+ PT.rtrait_instance_id_to_string fmt x
+
+ let strait_instance_id_to_string (ctx : C.eval_ctx) (x : T.strait_instance_id)
+ : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_stype_formatter fmt in
+ PT.strait_instance_id_to_string fmt x
+
let borrow_content_to_string (ctx : C.eval_ctx) (bc : V.borrow_content) :
string =
let fmt = PC.eval_ctx_to_ctx_formatter ctx in
@@ -653,11 +762,38 @@ module EvalCtxLlbcAst = struct
let fmt = PC.eval_ctx_to_ast_formatter ctx in
PE.operand_to_string fmt op
+ let call_to_string (ctx : C.eval_ctx) (call : A.call) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.call_to_string fmt "" call
+
+ let fun_decl_to_string (ctx : C.eval_ctx) (f : A.fun_decl) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.fun_decl_to_string fmt "" " " f
+
+ let fun_sig_to_string (ctx : C.eval_ctx) (x : A.fun_sig) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.fun_sig_to_string fmt "" " " x
+
+ let inst_fun_sig_to_string (ctx : C.eval_ctx) (x : LlbcAst.inst_fun_sig) :
+ string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ let fmt = PC.ast_to_value_formatter fmt in
+ PV.inst_fun_sig_to_string fmt x
+
+ let fun_id_or_trait_method_ref_to_string (ctx : C.eval_ctx)
+ (x : E.fun_id_or_trait_method_ref) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PE.fun_id_or_trait_method_ref_to_string fmt x "..."
+
let statement_to_string (ctx : C.eval_ctx) (indent : string)
(indent_incr : string) (e : A.statement) : string =
let fmt = PC.eval_ctx_to_ast_formatter ctx in
PA.statement_to_string fmt indent indent_incr e
+ let trait_impl_to_string (ctx : C.eval_ctx) (timpl : A.trait_impl) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.trait_impl_to_string fmt " " " " timpl
+
let env_elem_to_string (ctx : C.eval_ctx) (indent : string)
(indent_incr : string) (ev : C.env_elem) : string =
let fmt = PC.eval_ctx_to_ctx_formatter ctx in
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index cfb63ec2..ec75fcfd 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -8,6 +8,9 @@ type type_formatter = {
type_decl_id_to_string : TypeDeclId.id -> string;
const_generic_var_id_to_string : ConstGenericVarId.id -> string;
global_decl_id_to_string : GlobalDeclId.id -> string;
+ trait_decl_id_to_string : TraitDeclId.id -> string;
+ trait_impl_id_to_string : TraitImplId.id -> string;
+ trait_clause_id_to_string : TraitClauseId.id -> string;
}
type value_formatter = {
@@ -18,6 +21,9 @@ type value_formatter = {
adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string;
var_id_to_string : VarId.id -> string;
adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option;
+ trait_decl_id_to_string : TraitDeclId.id -> string;
+ trait_impl_id_to_string : TraitImplId.id -> string;
+ trait_clause_id_to_string : TraitClauseId.id -> string;
}
let value_to_type_formatter (fmt : value_formatter) : type_formatter =
@@ -26,6 +32,9 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter =
type_decl_id_to_string = fmt.type_decl_id_to_string;
const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
global_decl_id_to_string = fmt.global_decl_id_to_string;
+ trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
(* TODO: we need to store which variables we have encountered so far, and
@@ -42,6 +51,9 @@ type ast_formatter = {
adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option;
fun_decl_id_to_string : FunDeclId.id -> string;
global_decl_id_to_string : GlobalDeclId.id -> string;
+ trait_decl_id_to_string : TraitDeclId.id -> string;
+ trait_impl_id_to_string : TraitImplId.id -> string;
+ trait_clause_id_to_string : TraitClauseId.id -> string;
}
let ast_to_value_formatter (fmt : ast_formatter) : value_formatter =
@@ -53,6 +65,9 @@ let ast_to_value_formatter (fmt : ast_formatter) : value_formatter =
adt_variant_to_string = fmt.adt_variant_to_string;
var_id_to_string = fmt.var_id_to_string;
adt_field_names = fmt.adt_field_names;
+ trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let ast_to_type_formatter (fmt : ast_formatter) : type_formatter =
@@ -70,31 +85,51 @@ let literal_type_to_string = Print.PrimitiveValues.literal_type_to_string
let scalar_value_to_string = Print.PrimitiveValues.scalar_value_to_string
let literal_to_string = Print.PrimitiveValues.literal_to_string
+(* Remark: not using generic_params on purpose, because we may use parameters
+ which either come from LLBC or from pure, and the [generic_params] type
+ for those ASTs is not the same. Note that it works because we actually don't
+ need to know the trait clauses to print the AST: we can thus ignore them.
+*)
let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
(global_decls : A.global_decl GlobalDeclId.Map.t)
- (type_params : type_var list)
+ (trait_decls : A.trait_decl TraitDeclId.Map.t)
+ (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list)
(const_generic_params : const_generic_var list) : type_formatter =
let type_var_id_to_string vid =
- let var = T.TypeVarId.nth type_params vid in
+ let var = TypeVarId.nth type_params vid in
type_var_to_string var
in
let const_generic_var_id_to_string vid =
- let var = T.ConstGenericVarId.nth const_generic_params vid in
+ let var = ConstGenericVarId.nth const_generic_params vid in
const_generic_var_to_string var
in
let type_decl_id_to_string def_id =
- let def = T.TypeDeclId.Map.find def_id type_decls in
+ let def = TypeDeclId.Map.find def_id type_decls in
name_to_string def.name
in
let global_decl_id_to_string def_id =
- let def = T.GlobalDeclId.Map.find def_id global_decls in
+ let def = GlobalDeclId.Map.find def_id global_decls in
+ name_to_string def.name
+ in
+ let trait_decl_id_to_string def_id =
+ let def = TraitDeclId.Map.find def_id trait_decls in
+ name_to_string def.name
+ in
+ let trait_impl_id_to_string def_id =
+ let def = TraitImplId.Map.find def_id trait_impls in
name_to_string def.name
in
+ let trait_clause_id_to_string id =
+ Print.PT.trait_clause_id_to_pretty_string id
+ in
{
type_var_id_to_string;
type_decl_id_to_string;
const_generic_var_id_to_string;
global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
(* TODO: there is a bit of duplication with Print.fun_decl_to_ast_formatter.
@@ -106,19 +141,21 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
(fun_decls : A.fun_decl FunDeclId.Map.t)
(global_decls : A.global_decl GlobalDeclId.Map.t)
- (type_params : type_var list)
+ (trait_decls : A.trait_decl TraitDeclId.Map.t)
+ (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list)
(const_generic_params : const_generic_var list) : ast_formatter =
- let type_var_id_to_string vid =
- let var = T.TypeVarId.nth type_params vid in
- type_var_to_string var
- in
- let const_generic_var_id_to_string vid =
- let var = T.ConstGenericVarId.nth const_generic_params vid in
- const_generic_var_to_string var
- in
- let type_decl_id_to_string def_id =
- let def = T.TypeDeclId.Map.find def_id type_decls in
- name_to_string def.name
+ let ({
+ type_var_id_to_string;
+ type_decl_id_to_string;
+ const_generic_var_id_to_string;
+ global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
+ }
+ : type_formatter) =
+ mk_type_formatter type_decls global_decls trait_decls trait_impls
+ type_params const_generic_params
in
let adt_variant_to_string =
Print.Types.type_ctx_to_adt_variant_to_string_fun type_decls
@@ -137,10 +174,6 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
let def = FunDeclId.Map.find def_id fun_decls in
fun_name_to_string def.name
in
- let global_decl_id_to_string def_id =
- let def = GlobalDeclId.Map.find def_id global_decls in
- global_name_to_string def.name
- in
{
type_var_id_to_string;
const_generic_var_id_to_string;
@@ -151,6 +184,9 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
adt_field_to_string;
fun_decl_id_to_string;
global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
let assumed_ty_to_string (aty : assumed_ty) : string =
@@ -159,12 +195,11 @@ let assumed_ty_to_string (aty : assumed_ty) : string =
| Result -> "Result"
| Error -> "Error"
| Fuel -> "Fuel"
- | Option -> "Option"
- | Vec -> "Vec"
| Array -> "Array"
| Slice -> "Slice"
| Str -> "Str"
- | Range -> "Range"
+ | RawPtr Mut -> "MutRawPtr"
+ | RawPtr Const -> "ConstRawPtr"
let type_id_to_string (fmt : type_formatter) (id : type_id) : string =
match id with
@@ -182,20 +217,18 @@ let const_generic_to_string (fmt : type_formatter) (cg : T.const_generic) :
let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string =
match ty with
- | Adt (id, tys, cgs) -> (
- let tys = List.map (ty_to_string fmt false) tys in
- let cgs = List.map (const_generic_to_string fmt) cgs in
- let params = List.append tys cgs in
+ | Adt (id, generics) -> (
match id with
| Tuple ->
- assert (cgs = []);
- "(" ^ String.concat " * " tys ^ ")"
+ let generics = generic_args_to_strings fmt false generics in
+ "(" ^ String.concat " * " generics ^ ")"
| AdtId _ | Assumed _ ->
- let params_s =
- if params = [] then "" else " " ^ String.concat " " params
+ let generics = generic_args_to_strings fmt true generics in
+ let generics_s =
+ if generics = [] then "" else " " ^ String.concat " " generics
in
- let ty_s = type_id_to_string fmt id ^ params_s in
- if params <> [] && inside then "(" ^ ty_s ^ ")" else ty_s)
+ let ty_s = type_id_to_string fmt id ^ generics_s in
+ if generics <> [] && inside then "(" ^ ty_s ^ ")" else ty_s)
| TypeVar tv -> fmt.type_var_id_to_string tv
| Literal lty -> literal_type_to_string lty
| Arrow (arg_ty, ret_ty) ->
@@ -203,6 +236,71 @@ let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string =
ty_to_string fmt true arg_ty ^ " -> " ^ ty_to_string fmt false ret_ty
in
if inside then "(" ^ ty ^ ")" else ty
+ | TraitType (trait_ref, generics, type_name) ->
+ let trait_ref = trait_ref_to_string fmt false trait_ref in
+ let s =
+ if generics = empty_generic_args then trait_ref ^ "::" ^ type_name
+ else
+ let generics = generic_args_to_string fmt generics in
+ "(" ^ trait_ref ^ " " ^ generics ^ ")::" ^ type_name
+ in
+ if inside then "(" ^ s ^ ")" else s
+
+and generic_args_to_strings (fmt : type_formatter) (inside : bool)
+ (generics : generic_args) : string list =
+ let tys = List.map (ty_to_string fmt inside) generics.types in
+ let cgs = List.map (const_generic_to_string fmt) generics.const_generics in
+ let trait_refs =
+ List.map (trait_ref_to_string fmt inside) generics.trait_refs
+ in
+ List.concat [ tys; cgs; trait_refs ]
+
+and generic_args_to_string (fmt : type_formatter) (generics : generic_args) :
+ string =
+ String.concat " " (generic_args_to_strings fmt true generics)
+
+and trait_ref_to_string (fmt : type_formatter) (inside : bool) (tr : trait_ref)
+ : string =
+ let trait_id = trait_instance_id_to_string fmt false tr.trait_id in
+ let generics = generic_args_to_string fmt tr.generics in
+ let s = trait_id ^ generics in
+ if tr.generics = empty_generic_args || not inside then s else "(" ^ s ^ ")"
+
+and trait_instance_id_to_string (fmt : type_formatter) (inside : bool)
+ (id : trait_instance_id) : string =
+ match id with
+ | Self -> "Self"
+ | TraitImpl id -> fmt.trait_impl_id_to_string id
+ | Clause id -> fmt.trait_clause_id_to_string id
+ | ParentClause (inst_id, _decl_id, clause_id) ->
+ let inst_id = trait_instance_id_to_string fmt false inst_id in
+ let clause_id = fmt.trait_clause_id_to_string clause_id in
+ "parent(" ^ inst_id ^ ")::" ^ clause_id
+ | ItemClause (inst_id, _decl_id, item_name, clause_id) ->
+ let inst_id = trait_instance_id_to_string fmt false inst_id in
+ let clause_id = fmt.trait_clause_id_to_string clause_id in
+ "(" ^ inst_id ^ ")::" ^ item_name ^ "::[" ^ clause_id ^ "]"
+ | TraitRef tr -> trait_ref_to_string fmt inside tr
+ | UnknownTrait msg -> "UNKNOWN(" ^ msg ^ ")"
+
+let trait_clause_to_string (fmt : type_formatter) (clause : trait_clause) :
+ string =
+ let clause_id = fmt.trait_clause_id_to_string clause.clause_id in
+ let trait_id = fmt.trait_decl_id_to_string clause.trait_id in
+ let generics = generic_args_to_strings fmt true clause.generics in
+ let generics =
+ if generics = [] then "" else " " ^ String.concat " " generics
+ in
+ "[" ^ clause_id ^ "]: " ^ trait_id ^ generics
+
+let generic_params_to_strings (fmt : type_formatter) (generics : generic_params)
+ : string list =
+ let tys = List.map type_var_to_string generics.types in
+ let cgs = List.map const_generic_var_to_string generics.const_generics in
+ let trait_clauses =
+ List.map (trait_clause_to_string fmt) generics.trait_clauses
+ in
+ List.concat [ tys; cgs; trait_clauses ]
let field_to_string fmt inside (f : field) : string =
match f.field_name with
@@ -217,11 +315,10 @@ let variant_to_string fmt (v : variant) : string =
^ ")"
let type_decl_to_string (fmt : type_formatter) (def : type_decl) : string =
- let types = def.type_params in
let name = name_to_string def.name in
let params =
- if types = [] then ""
- else " " ^ String.concat " " (List.map type_var_to_string types)
+ if def.generics = empty_generic_params then ""
+ else " " ^ String.concat " " (generic_params_to_strings fmt def.generics)
in
match def.kind with
| Struct fields ->
@@ -256,10 +353,6 @@ let rec mprojection_to_string (fmt : ast_formatter) (inside : string)
| pe :: p' -> (
let s = mprojection_to_string fmt inside p' in
match pe.pkind with
- | E.ProjOption variant_id ->
- assert (variant_id = T.option_some_id);
- assert (pe.field_id = T.FieldId.zero);
- "(" ^ s ^ "as Option::Some)." ^ T.FieldId.to_string pe.field_id
| E.ProjTuple _ -> "(" ^ s ^ ")." ^ T.FieldId.to_string pe.field_id
| E.ProjAdt (adt_id, opt_variant_id) -> (
let field_name =
@@ -294,11 +387,9 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id)
| Assumed aty -> (
(* Assumed type *)
match aty with
- | State | Array | Slice | Str ->
+ | State | Array | Slice | Str | RawPtr _ ->
(* Those types are opaque: we can't get there *)
raise (Failure "Unreachable")
- | Vec -> "@Vec"
- | Range -> "@Range"
| Result ->
let variant_id = Option.get variant_id in
if variant_id = result_return_id then "@Result::Return"
@@ -314,13 +405,7 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id)
let variant_id = Option.get variant_id in
if variant_id = fuel_zero_id then "@Fuel::Zero"
else if variant_id = fuel_succ_id then "@Fuel::Succ"
- else raise (Failure "Unreachable: improper variant id for fuel type")
- | Option ->
- let variant_id = Option.get variant_id in
- if variant_id = option_some_id then "@Option::Some "
- else if variant_id = option_none_id then "@Option::None"
- else
- raise (Failure "Unreachable: improper variant id for result type"))
+ else raise (Failure "Unreachable: improper variant id for fuel type"))
let adt_field_to_string (fmt : value_formatter) (adt_id : type_id)
(field_id : FieldId.id) : string =
@@ -337,11 +422,10 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id)
| Assumed aty -> (
(* Assumed type *)
match aty with
- | Range -> FieldId.to_string field_id
- | State | Fuel | Vec | Array | Slice | Str ->
+ | State | Fuel | Array | Slice | Str ->
(* Opaque types: we can't get there *)
raise (Failure "Unreachable")
- | Result | Error | Option ->
+ | Result | Error | RawPtr _ ->
(* Enumerations: we can't get there *)
raise (Failure "Unreachable"))
@@ -353,10 +437,10 @@ let adt_g_value_to_string (fmt : value_formatter)
(field_values : 'v list) (ty : ty) : string =
let field_values = List.map value_to_string field_values in
match ty with
- | Adt (Tuple, _, _) ->
+ | Adt (Tuple, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | Adt (AdtId def_id, _, _) ->
+ | Adt (AdtId def_id, _) ->
(* "Regular" ADT *)
let adt_ident =
match variant_id with
@@ -378,10 +462,10 @@ let adt_g_value_to_string (fmt : value_formatter)
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | Adt (Assumed aty, _, _) -> (
+ | Adt (Assumed aty, _) -> (
(* Assumed type *)
match aty with
- | State ->
+ | State | RawPtr _ ->
(* This type is opaque: we can't get there *)
raise (Failure "Unreachable")
| Result ->
@@ -412,31 +496,13 @@ let adt_g_value_to_string (fmt : value_formatter)
| [ v ] -> "@Fuel::Succ " ^ v
| _ -> raise (Failure "@Fuel::Succ takes exactly one value")
else raise (Failure "Unreachable: improper variant id for fuel type")
- | Option ->
- let variant_id = Option.get variant_id in
- if variant_id = option_some_id then
- match field_values with
- | [ v ] -> "@Option::Some " ^ v
- | _ -> raise (Failure "Option::Some takes exactly one value")
- else if variant_id = option_none_id then (
- assert (field_values = []);
- "@Option::None")
- else
- raise (Failure "Unreachable: improper variant id for result type")
- | Vec | Array | Slice | Str ->
+ | Array | Slice | Str ->
assert (variant_id = None);
let field_values =
List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values
in
let id = assumed_ty_to_string aty in
- id ^ " [" ^ String.concat "; " field_values ^ "]"
- | Range ->
- assert (variant_id = None);
- let field_values =
- List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values
- in
- let id = assumed_ty_to_string aty in
- id ^ " {" ^ String.concat "; " field_values ^ "}")
+ id ^ " [" ^ String.concat "; " field_values ^ "]")
| _ ->
let fmt = value_to_type_formatter fmt in
raise
@@ -464,10 +530,10 @@ let rec typed_pattern_to_string (fmt : ast_formatter) (v : typed_pattern) :
let fun_sig_to_string (fmt : ast_formatter) (sg : fun_sig) : string =
let ty_fmt = ast_to_type_formatter fmt in
- let type_params = List.map type_var_to_string sg.type_params in
+ let generics = generic_params_to_strings ty_fmt sg.generics in
let inputs = List.map (ty_to_string ty_fmt false) sg.inputs in
let output = ty_to_string ty_fmt false sg.output in
- let all_types = List.concat [ type_params; inputs; [ output ] ] in
+ let all_types = List.concat [ generics; inputs; [ output ] ] in
String.concat " -> " all_types
let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string =
@@ -495,28 +561,16 @@ let fun_suffix (lp_id : LoopId.id option) (rg_id : T.RegionGroupId.id option) :
let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string =
match fid with
- | A.Replace -> "core::mem::replace"
- | A.BoxNew -> "alloc::boxed::Box::new"
- | A.BoxDeref -> "core::ops::deref::Deref::deref"
- | A.BoxDerefMut -> "core::ops::deref::DerefMut::deref_mut"
- | A.BoxFree -> "alloc::alloc::box_free"
- | A.VecNew -> "alloc::vec::Vec::new"
- | A.VecPush -> "alloc::vec::Vec::push"
- | A.VecInsert -> "alloc::vec::Vec::insert"
- | A.VecLen -> "alloc::vec::Vec::len"
- | A.VecIndex -> "core::ops::index::Index<alloc::vec::Vec>::index"
- | A.VecIndexMut -> "core::ops::index::IndexMut<alloc::vec::Vec>::index_mut"
+ | BoxNew -> "alloc::boxed::Box::new"
+ | BoxFree -> "alloc::alloc::box_free"
| ArrayIndexShared -> "@ArrayIndexShared"
| ArrayIndexMut -> "@ArrayIndexMut"
| ArrayToSliceShared -> "@ArrayToSliceShared"
| ArrayToSliceMut -> "@ArrayToSliceMut"
- | ArraySubsliceShared -> "@ArraySubsliceShared"
- | ArraySubsliceMut -> "@ArraySubsliceMut"
+ | ArrayRepeat -> "@ArrayRepeat"
| SliceLen -> "@SliceLen"
| SliceIndexShared -> "@SliceIndexShared"
| SliceIndexMut -> "@SliceIndexMut"
- | SliceSubsliceShared -> "@SliceSubsliceShared"
- | SliceSubsliceMut -> "@SliceSubsliceMut"
let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string =
match fid with
@@ -531,8 +585,11 @@ let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string =
| FromLlbc (fid, lp_id, rg_id) ->
let f =
match fid with
- | Regular fid -> fmt.fun_decl_id_to_string fid
- | Assumed fid -> llbc_assumed_fun_id_to_string fid
+ | FunId (Regular fid) -> fmt.fun_decl_id_to_string fid
+ | FunId (Assumed fid) -> llbc_assumed_fun_id_to_string fid
+ | TraitMethod (trait_ref, method_name, _) ->
+ let fmt = ast_to_type_formatter fmt in
+ trait_ref_to_string fmt true trait_ref ^ "." ^ method_name
in
f ^ fun_suffix lp_id rg_id
| Pure fid -> pure_assumed_fun_id_to_string fid
@@ -559,9 +616,8 @@ let fun_or_op_id_to_string (fmt : ast_formatter) (fun_id : fun_or_op_id) :
let rec texpression_to_string (fmt : ast_formatter) (inside : bool)
(indent : string) (indent_incr : string) (e : texpression) : string =
match e.e with
- | Var var_id ->
- let s = fmt.var_id_to_string var_id in
- if inside then "(" ^ s ^ ")" else s
+ | Var var_id -> fmt.var_id_to_string var_id
+ | CVar cg_id -> fmt.const_generic_var_id_to_string cg_id
| Const cv -> literal_to_string cv
| App _ ->
(* Recursively destruct the app, to have a pair (app, arguments list) *)
@@ -632,10 +688,11 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
(* There are two possibilities: either the [app] is an instantiated,
* top-level qualifier (function, ADT constructore...), or it is a "regular"
* expression *)
- let app, tys =
+ let app, generics =
match app.e with
| Qualif qualif ->
(* Qualifier case *)
+ let ty_fmt = ast_to_type_formatter fmt in
(* Convert the qualifier identifier *)
let qualif_s =
match qualif.id with
@@ -654,12 +711,17 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
let field_s = adt_field_to_string value_fmt adt_id field_id in
(* Adopting an F*-like syntax *)
ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s
+ | TraitConst (trait_ref, generics, const_name) ->
+ let trait_ref = trait_ref_to_string ty_fmt true trait_ref in
+ let generics_s = generic_args_to_string ty_fmt generics in
+ if generics <> empty_generic_args then
+ "(" ^ trait_ref ^ generics_s ^ ")." ^ const_name
+ else trait_ref ^ "." ^ const_name
in
(* Convert the type instantiation *)
- let ty_fmt = ast_to_type_formatter fmt in
- let tys = List.map (ty_to_string ty_fmt true) qualif.type_args in
+ let generics = generic_args_to_strings ty_fmt true qualif.generics in
(* *)
- (qualif_s, tys)
+ (qualif_s, generics)
| _ ->
(* "Regular" expression case *)
let inside = args <> [] || (args = [] && inside) in
@@ -674,7 +736,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
texpression_to_string fmt inside indent1 indent_incr
in
let args = List.map arg_to_string args in
- let all_args = List.append tys args in
+ let all_args = List.append generics args in
(* Put together *)
let e =
if all_args = [] then app else app ^ " " ^ String.concat " " all_args
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index ac4ca081..e6a3dab5 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -13,6 +13,9 @@ module FieldId = T.FieldId
module SymbolicValueId = V.SymbolicValueId
module FunDeclId = A.FunDeclId
module GlobalDeclId = A.GlobalDeclId
+module TraitDeclId = T.TraitDeclId
+module TraitImplId = T.TraitImplId
+module TraitClauseId = T.TraitClauseId
(** We redefine identifiers for loop: in {!Values}, the identifiers are global
(they monotonically increase across functions) while in {!module:Pure} we want
@@ -21,8 +24,6 @@ module GlobalDeclId = A.GlobalDeclId
module LoopId =
IdGen ()
-type loop_id = LoopId.id [@@deriving show, ord]
-
(** We give an identifier to every phase of the synthesis (forward, backward
for group of regions 0, etc.) *)
module SynthPhaseId =
@@ -37,6 +38,16 @@ module ConstGenericVarId = T.ConstGenericVarId
type integer_type = T.integer_type [@@deriving show, ord]
type const_generic_var = T.const_generic_var [@@deriving show, ord]
type const_generic = T.const_generic [@@deriving show, ord]
+type const_generic_var_id = T.const_generic_var_id [@@deriving show, ord]
+type trait_decl_id = T.trait_decl_id [@@deriving show, ord]
+type trait_impl_id = T.trait_impl_id [@@deriving show, ord]
+type trait_clause_id = T.trait_clause_id [@@deriving show, ord]
+type trait_item_name = T.trait_item_name [@@deriving show, ord]
+type global_decl_id = T.global_decl_id [@@deriving show, ord]
+type fun_decl_id = A.fun_decl_id [@@deriving show, ord]
+type loop_id = LoopId.id [@@deriving show, ord]
+type region_group_id = T.region_group_id [@@deriving show, ord]
+type mutability = Mut | Const [@@deriving show, ord]
(** The assumed types for the pure AST.
@@ -59,12 +70,17 @@ type assumed_ty =
| Result
| Error
| Fuel
- | Vec
- | Option
| Array
| Slice
| Str
- | Range
+ | RawPtr of mutability
+ (** The bool
+ Raw pointers don't make sense in the pure world, but we don't know
+ how to translate them yet and we have to handle some functions which
+ use raw pointers in their signature (for instance some trait declarations
+ for the slices). For now, we use a dedicated type to "mark" the raw pointers,
+ and make sure that those functions are actually not used in the translation.
+ *)
[@@deriving show, ord]
(* TODO: we should never directly manipulate [Return] and [Fail], but rather
@@ -176,6 +192,14 @@ class ['self] iter_ty_base =
inherit! [_] T.iter_const_generic
inherit! [_] PV.iter_literal_type
method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> ()
+ method visit_trait_decl_id : 'env -> trait_decl_id -> unit = fun _ _ -> ()
+ method visit_trait_impl_id : 'env -> trait_impl_id -> unit = fun _ _ -> ()
+
+ method visit_trait_clause_id : 'env -> trait_clause_id -> unit =
+ fun _ _ -> ()
+
+ method visit_trait_item_name : 'env -> trait_item_name -> unit =
+ fun _ _ -> ()
end
(** Ancestor for map visitor for [ty] *)
@@ -185,6 +209,18 @@ class ['self] map_ty_base =
inherit! [_] T.map_const_generic
inherit! [_] PV.map_literal_type
method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x
+
+ method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id =
+ fun _ x -> x
+
+ method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id =
+ fun _ x -> x
+
+ method visit_trait_clause_id : 'env -> trait_clause_id -> trait_clause_id =
+ fun _ x -> x
+
+ method visit_trait_item_name : 'env -> trait_item_name -> trait_item_name =
+ fun _ x -> x
end
(** Ancestor for reduce visitor for [ty] *)
@@ -194,6 +230,18 @@ class virtual ['self] reduce_ty_base =
inherit! [_] T.reduce_const_generic
inherit! [_] PV.reduce_literal_type
method visit_type_var_id : 'env -> type_var_id -> 'a = fun _ _ -> self#zero
+
+ method visit_trait_decl_id : 'env -> trait_decl_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_trait_impl_id : 'env -> trait_impl_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_trait_clause_id : 'env -> trait_clause_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_trait_item_name : 'env -> trait_item_name -> 'a =
+ fun _ _ -> self#zero
end
(** Ancestor for mapreduce visitor for [ty] *)
@@ -205,10 +253,24 @@ class virtual ['self] mapreduce_ty_base =
method visit_type_var_id : 'env -> type_var_id -> type_var_id * 'a =
fun _ x -> (x, self#zero)
+
+ method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_trait_clause_id
+ : 'env -> trait_clause_id -> trait_clause_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_trait_item_name
+ : 'env -> trait_item_name -> trait_item_name * 'a =
+ fun _ x -> (x, self#zero)
end
type ty =
- | Adt of type_id * ty list * const_generic list
+ | Adt of type_id * generic_args
(** {!Adt} encodes ADTs and tuples and assumed types.
TODO: what about the ended regions? (ADTs may be parameterized
@@ -219,8 +281,38 @@ type ty =
| TypeVar of type_var_id
| Literal of literal_type
| Arrow of ty * ty
+ | TraitType of trait_ref * generic_args * string
+ (** The string is for the name of the associated type *)
+
+and trait_ref = {
+ trait_id : trait_instance_id;
+ generics : generic_args;
+ trait_decl_ref : trait_decl_ref;
+}
+
+and trait_decl_ref = {
+ trait_decl_id : trait_decl_id;
+ decl_generics : generic_args; (* The name: annoying field collisions... *)
+}
+
+and generic_args = {
+ types : ty list;
+ const_generics : const_generic list;
+ trait_refs : trait_ref list;
+}
+
+and trait_instance_id =
+ | Self
+ | TraitImpl of trait_impl_id
+ | Clause of trait_clause_id
+ | ParentClause of trait_instance_id * trait_decl_id * trait_clause_id
+ | ItemClause of
+ trait_instance_id * trait_decl_id * trait_item_name * trait_clause_id
+ | TraitRef of trait_ref
+ | UnknownTrait of string
[@@deriving
show,
+ ord,
visitors
{
name = "iter_ty";
@@ -264,12 +356,37 @@ type type_decl_kind = Struct of field list | Enum of variant list | Opaque
type type_var = T.type_var [@@deriving show]
+type trait_clause = {
+ clause_id : trait_clause_id;
+ trait_id : trait_decl_id;
+ generics : generic_args;
+}
+[@@deriving show]
+
+type generic_params = {
+ types : type_var list;
+ const_generics : const_generic_var list;
+ trait_clauses : trait_clause list;
+}
+[@@deriving show]
+
+type trait_type_constraint = {
+ trait_ref : trait_ref;
+ generics : generic_args;
+ type_name : trait_item_name;
+ ty : ty;
+}
+[@@deriving show, ord]
+
+type predicates = { trait_type_constraints : trait_type_constraint list }
+[@@deriving show]
+
type type_decl = {
def_id : TypeDeclId.id;
name : name;
- type_params : type_var list;
- const_generic_params : const_generic_var list;
+ generics : generic_params;
kind : type_decl_kind;
+ preds : predicates;
}
[@@deriving show]
@@ -420,8 +537,15 @@ type pure_assumed_fun_id =
| FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *)
[@@deriving show, ord]
+type fun_id_or_trait_method_ref =
+ | FunId of A.fun_id
+ | TraitMethod of trait_ref * string * fun_decl_id
+ (** The fun decl id is not really needed and here for convenience purposes *)
+[@@deriving show, ord]
+
(** A function id for a non-assumed function *)
-type regular_fun_id = A.fun_id * LoopId.id option * T.RegionGroupId.id option
+type regular_fun_id =
+ fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option
[@@deriving show, ord]
(** A function identifier *)
@@ -457,23 +581,20 @@ type projection = { adt_id : type_id; field_id : FieldId.id } [@@deriving show]
type qualif_id =
| FunOrOp of fun_or_op_id (** A function or an operation *)
- | Global of GlobalDeclId.id
+ | Global of global_decl_id
| AdtCons of adt_cons_id (** A function or ADT constructor identifier *)
| Proj of projection (** Field projector *)
+ | TraitConst of trait_ref * generic_args * string
+ (** A trait associated constant *)
[@@deriving show]
-(** An instantiated qualified.
+(** An instantiated qualifier.
Note that for now we have a clear separation between types and expressions,
- which explains why we have the [type_params] field: a function or ADT
+ which explains why we have the [generics] field: a function or ADT
constructor is always fully instantiated.
*)
-type qualif = {
- id : qualif_id;
- type_args : ty list;
- const_generic_args : const_generic list;
-}
-[@@deriving show]
+type qualif = { id : qualif_id; generics : generic_args } [@@deriving show]
type field_id = FieldId.id [@@deriving show, ord]
type var_id = VarId.id [@@deriving show, ord]
@@ -536,6 +657,7 @@ class virtual ['self] mapreduce_expression_base =
*)
type expression =
| Var of var_id (** a variable *)
+ | CVar of const_generic_var_id (** a const generic var *)
| Const of literal
| App of texpression * texpression
(** Application of a function to an argument.
@@ -787,11 +909,11 @@ type fun_sig_info = {
- etc.
*)
type fun_sig = {
- type_params : type_var list;
- const_generic_params : const_generic_var list;
+ generics : generic_params;
(** TODO: we should analyse the signature to make the type parameters implicit whenever possible *)
+ preds : predicates;
inputs : ty list;
- (** The input types.
+ (** The types of the inputs.
Note that those input types take into account the [fuel] parameter,
if the function uses fuel for termination, and the [state] parameter,
@@ -861,8 +983,11 @@ type fun_body = {
}
[@@deriving show]
+type fun_kind = A.fun_kind [@@deriving show]
+
type fun_decl = {
def_id : FunDeclId.id;
+ kind : fun_kind;
num_loops : int;
(** The number of loops in the parent forward function (basically the number
of loops appearing in the original Rust functions, unless some loops are
@@ -882,3 +1007,30 @@ type fun_decl = {
body : fun_body option;
}
[@@deriving show]
+
+type trait_decl = {
+ def_id : trait_decl_id;
+ name : name;
+ generics : generic_params;
+ preds : predicates;
+ parent_clauses : trait_clause list;
+ consts : (trait_item_name * (ty * global_decl_id option)) list;
+ types : (trait_item_name * (trait_clause list * ty option)) list;
+ required_methods : (trait_item_name * fun_decl_id) list;
+ provided_methods : (trait_item_name * fun_decl_id option) list;
+}
+[@@deriving show]
+
+type trait_impl = {
+ def_id : trait_impl_id;
+ name : name;
+ impl_trait : trait_decl_ref;
+ generics : generic_params;
+ preds : predicates;
+ parent_trait_refs : trait_ref list;
+ consts : (trait_item_name * (ty * global_decl_id)) list;
+ types : (trait_item_name * (trait_ref list * ty)) list;
+ required_methods : (trait_item_name * fun_decl_id) list;
+ provided_methods : (trait_item_name * fun_decl_id) list;
+}
+[@@deriving show]
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index b6025df4..f3e6cbe2 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -376,8 +376,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ty = e.ty in
let ctx, e =
match e.e with
- | Var _ -> (* Nothing to do *) (ctx, e.e)
- | Const _ -> (* Nothing to do *) (ctx, e.e)
+ | Var _ | CVar _ | Const _ -> (* Nothing to do *) (ctx, e.e)
| App (app, arg) ->
let ctx, app = update_texpression app ctx in
let ctx, arg = update_texpression arg ctx in
@@ -584,13 +583,10 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
| Qualif
{
id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
- type_args = _;
- const_generic_args = _;
+ generics = _;
} ->
(* Lookup the def *)
- let decl =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
- in
+ let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in
(* Check that there are as many arguments as there are fields - note
that the def should have a body (otherwise we couldn't use the
constructor) *)
@@ -599,8 +595,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* Check if the definition is recursive *)
let is_rec =
match
- TypeDeclId.Map.find adt_id
- ctx.type_context.type_decls_groups
+ TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups
with
| NonRec _ -> false
| Rec _ -> true
@@ -682,8 +677,8 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
| _ -> false
in
(* And either:
- * 2.1 the right-expression is a variable or a global *)
- let var_or_global = is_var re || is_global re in
+ * 2.1 the right-expression is a variable, a global or a const generic var *)
+ let var_or_global = is_var re || is_cvar re || is_global re in
(* Or:
* 2.2 the right-expression is a constant value, an ADT value,
* a projection or a primitive function call *and* the flag
@@ -767,10 +762,10 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
In this situation, we can remove the call [f@fwd x].
*)
let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : A.fun_id) (lp_id0 : LoopId.id option)
- (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list)
+ (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option)
+ (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args)
(args0 : texpression list) (e : texpression) : bool =
- let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list)
+ let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args)
(args1 : texpression list) : bool =
(* Check the fun_ids, to see if call1's function is a child of call0's function *)
match fun_id1 with
@@ -793,7 +788,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
(* We need to use the regions hierarchy *)
(* First, lookup the signature of the LLBC function *)
let sg =
- LlbcAstUtils.lookup_fun_sig id0 ctx.fun_context.fun_decls
+ let id0 =
+ match id0 with
+ | FunId fun_id -> fun_id
+ | TraitMethod (_, _, fun_decl_id) -> Regular fun_decl_id
+ in
+ LlbcAstUtils.lookup_fun_sig id0 ctx.fun_ctx.fun_decls
in
(* Compute the set of ancestors of the function in call1 *)
let call1_ancestors =
@@ -817,8 +817,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
let input_eq (v0, v1) =
PureUtils.remove_meta v0 = PureUtils.remove_meta v1
in
- (* Compare the input types and the prefix of the input arguments *)
- tys0 = tys1 && List.for_all input_eq args
+ (* Compare the generics and the prefix of the input arguments *)
+ generics0 = generics1 && List.for_all input_eq args
else (* Not a child *)
false
else (* Not the same function *)
@@ -834,7 +834,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
method! visit_texpression env e =
match e.e with
- | Var _ | Const _ -> fun _ -> false
+ | Var _ | CVar _ | Const _ -> fun _ -> false
| StructUpdate _ ->
(* There shouldn't be monadic calls in structure updates - also
note that by returning [false] we are conservative: we might
@@ -844,8 +844,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
| Let (_, _, re, e) -> (
match opt_destruct_function_call re with
| None -> fun () -> self#visit_texpression env e ()
- | Some (func1, tys1, args1) ->
- let call_is_child = check_call func1 tys1 args1 in
+ | Some (func1, generics1, args1) ->
+ let call_is_child = check_call func1 generics1 args1 in
if call_is_child then fun () -> true
else fun () -> self#visit_texpression env e ())
| App _ -> (
@@ -930,7 +930,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
method! visit_expression env e =
match e with
- | Var _ | Const _ | App _ | Qualif _
+ | Var _ | CVar _ | Const _ | App _ | Qualif _
| Switch (_, _)
| Meta (_, _)
| StructUpdate _ | Abs _ ->
@@ -1086,13 +1086,12 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
| Qualif
{
id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
- type_args;
- const_generic_args;
+ generics;
} ->
(* This is a struct *)
(* Retrieve the definiton, to find how many fields there are *)
let adt_decl =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
+ TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls
in
let fields =
match adt_decl.kind with
@@ -1108,7 +1107,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
* [x.field] for some variable [x], and where the projection
* is for the proper ADT *)
let to_var_proj (i : int) (arg : texpression) :
- (ty list * const_generic list * var_id) option =
+ (generic_args * var_id) option =
match arg.e with
| App (proj, x) -> (
match (proj.e, x.e) with
@@ -1116,16 +1115,14 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
{
id =
Proj { adt_id = AdtId proj_adt_id; field_id };
- type_args = proj_type_args;
- const_generic_args = proj_const_generic_args;
+ generics = proj_generics;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
if
proj_adt_id = adt_id
&& FieldId.to_int field_id = i
- then
- Some (proj_type_args, proj_const_generic_args, v)
+ then Some (proj_generics, v)
else None
| _ -> None)
| _ -> None
@@ -1136,14 +1133,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
if List.length args = num_fields then
(* Check that this is the same variable we project from -
* note that we checked above that there is at least one field *)
- let (_, _, x), end_args = Collections.List.pop args in
- if List.for_all (fun (_, _, y) -> y = x) end_args then (
+ let (_, x), end_args = Collections.List.pop args in
+ if List.for_all (fun (_, y) -> y = x) end_args then (
(* We can substitute *)
(* Sanity check: all types correct *)
assert (
List.for_all
- (fun (tys, cgs, _) ->
- tys = type_args && cgs = const_generic_args)
+ (fun (generics1, _) -> generics1 = generics)
args);
{ e with e = Var x })
else super#visit_texpression env e
@@ -1162,8 +1158,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
| ( Qualif
{
id = Proj { adt_id = AdtId proj_adt_id; field_id };
- type_args = _;
- const_generic_args = _;
+ generics = _;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
@@ -1361,8 +1356,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_sig =
{
- type_params = fun_sig.type_params;
- const_generic_params = fun_sig.const_generic_params;
+ generics = fun_sig.generics;
+ preds = fun_sig.preds;
inputs = inputs_tys;
output;
doutputs;
@@ -1427,6 +1422,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_def =
{
def_id = def.def_id;
+ kind = def.kind;
num_loops;
loop_id = Some loop.loop_id;
back_id = def.back_id;
@@ -1466,13 +1462,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
In such situation, we can remove the forward function definition
altogether.
*)
-let keep_forward (trans : pure_fun_translation) : bool =
- let (fwd, _), backs = trans in
+let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
(* Note that at this point, the output types are no longer seen as tuples:
* they should be lists of length 1. *)
if
!Config.filter_useless_functions
- && fwd.signature.output = mk_result_ty mk_unit_ty
+ && fwd.f.signature.output = mk_result_ty mk_unit_ty
&& backs <> []
then false
else true
@@ -1518,7 +1513,7 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl =
function calls, and when translating end abstractions. Here, we can do
something simpler, in one micro-pass.
*)
-let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
+let eliminate_box_functions (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* The map visitor *)
let obj =
object
@@ -1527,30 +1522,44 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
method! visit_texpression env e =
match opt_destruct_function_call e with
| Some (fun_id, _tys, args) -> (
+ (* Below, when dealing with the arguments: we consider the very
+ * general case, where functions could be boxed (meaning we
+ * could have: [box_new f x])
+ * *)
match fun_id with
- | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> (
- (* Below, when dealing with the arguments: we consider the very
- * general case, where functions could be boxed (meaning we
- * could have: [box_new f x])
- * *)
+ | Fun (FromLlbc (FunId (Assumed aid), _lp_id, rg_id)) -> (
match (aid, rg_id) with
- | A.BoxNew, _ ->
+ | BoxNew, _ ->
assert (rg_id = None);
let arg, args = Collections.List.pop args in
mk_apps arg args
- | A.BoxDeref, None ->
+ | BoxFree, _ ->
+ assert (args = []);
+ mk_unit_rvalue
+ | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared
+ | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
+ | ArrayRepeat | SliceLen ),
+ _ ) ->
+ super#visit_texpression env e)
+ | Fun (FromLlbc (FunId (Regular fid), _lp_id, rg_id)) -> (
+ (* Lookup the function name *)
+ let def = FunDeclId.Map.find fid ctx.fun_ctx.fun_decls in
+ match
+ (Names.name_no_disambiguators_to_string def.name, rg_id)
+ with
+ | "alloc::boxed::Box::deref", None ->
(* [Box::deref] forward is the identity *)
let arg, args = Collections.List.pop args in
mk_apps arg args
- | A.BoxDeref, Some _ ->
+ | "alloc::boxed::Box::deref", Some _ ->
(* [Box::deref] backward is [()] (doesn't give back anything) *)
assert (args = []);
mk_unit_rvalue
- | A.BoxDerefMut, None ->
+ | "alloc::boxed::Box::deref_mut", None ->
(* [Box::deref_mut] forward is the identity *)
let arg, args = Collections.List.pop args in
mk_apps arg args
- | A.BoxDerefMut, Some _ ->
+ | "alloc::boxed::Box::deref_mut", Some _ ->
(* [Box::deref_mut] back is almost the identity:
* let box_deref_mut (x_init : t) (x_back : t) : t = x_back
* *)
@@ -1560,17 +1569,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
| _ -> raise (Failure "Unreachable")
in
mk_apps arg args
- | A.BoxFree, _ ->
- assert (args = []);
- mk_unit_rvalue
- | ( ( A.Replace | VecNew | VecPush | VecInsert | VecLen
- | VecIndex | VecIndexMut | ArraySubsliceShared
- | ArraySubsliceMut | SliceIndexShared | SliceIndexMut
- | SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared
- | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
- | SliceLen ),
- _ ) ->
- super#visit_texpression env e)
+ | _ -> super#visit_texpression env e)
| _ -> super#visit_texpression env e)
| _ -> super#visit_texpression env e
end
@@ -1914,7 +1913,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
[ctx]: used only for printing.
*)
let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
- (fun_decl * fun_decl list) option =
+ fun_and_loops option =
(* Debug *)
log#ldebug
(lazy
@@ -1955,9 +1954,9 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
let def, loops = decompose_loops def in
(* Apply the remaining passes *)
- let def = apply_end_passes_to_def ctx def in
+ let f = apply_end_passes_to_def ctx def in
let loops = List.map (apply_end_passes_to_def ctx) loops in
- Some (def, loops)
+ Some { f; loops }
(** Small utility for {!filter_loop_inputs} *)
let filter_prefix (keep : bool list) (ls : 'a list) : 'a list =
@@ -1983,8 +1982,8 @@ end
module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType)
(** Filter the useless loop input parameters. *)
-let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
- (bool * pure_fun_translation) list =
+let filter_loop_inputs (transl : pure_fun_translation list) :
+ pure_fun_translation list =
(* We need to explore groups of mutually recursive functions. In order
to compute which parameters are useless, we need to explore the
functions by groups of mutually recursive definitions.
@@ -2002,10 +2001,11 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
(List.concat
(List.concat
(List.map
- (fun (_, ((fwd, loops_fwd), backs)) ->
- [ fwd :: loops_fwd ]
+ (fun { fwd; backs; _ } ->
+ [ fwd.f :: fwd.loops ]
:: List.map
- (fun (back, loops_back) -> [ back :: loops_back ])
+ (fun { f = back; loops = loops_back } ->
+ [ back :: loops_back ])
backs)
transl)))
in
@@ -2030,7 +2030,6 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
additional parameters.
*)
let used_map = ref FunLoopIdMap.empty in
- let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in
(* We start by computing the filtering information, for each function *)
let compute_one_filter_info (decl : fun_decl) =
@@ -2051,7 +2050,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
let inputs_set = VarId.Set.of_list (List.map var_get_id inputs_prefix) in
assert (Option.is_some decl.loop_id);
- let fun_id = (A.Regular decl.def_id, decl.loop_id) in
+ let fun_id = (E.Regular decl.def_id, decl.loop_id) in
let set_used vid =
used := List.map (fun (vid', b) -> (vid', b || vid = vid')) !used
@@ -2075,8 +2074,8 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id')) ->
- if fun_id_to_fun_loop_id fun_id' = fun_id then (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) ->
+ if (fun_id', loop_id') = fun_id then (
(* For each argument, check if it is exactly the original
input parameter. Note that there shouldn't be partial
applications of loop functions: the number of arguments
@@ -2135,22 +2134,15 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
(* We then apply the filtering to all the function definitions at once *)
let filter_in_one (decl : fun_decl) : fun_decl =
(* Filter the function signature *)
- let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in
+ let fun_id = (E.Regular decl.def_id, decl.loop_id) in
let decl =
- match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) decl
| Some used_info ->
let num_filtered =
List.length (List.filter (fun b -> not b) used_info)
in
- let {
- type_params;
- const_generic_params;
- inputs;
- output;
- doutputs;
- info;
- } =
+ let { generics; preds; inputs; output; doutputs; info } =
decl.signature
in
let {
@@ -2178,16 +2170,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
effect_info;
}
in
- let signature =
- {
- type_params;
- const_generic_params;
- inputs;
- output;
- doutputs;
- info;
- }
- in
+ let signature = { generics; preds; inputs; output; doutputs; info } in
{ decl with signature }
in
@@ -2201,9 +2184,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
let { inputs; inputs_lvs; body } = body in
let inputs, inputs_lvs =
- match
- FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map
- with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) (inputs, inputs_lvs)
| Some used_info ->
let inputs = filter_prefix used_info inputs in
@@ -2223,11 +2204,10 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id)) -> (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _)))
+ -> (
match
- FunLoopIdMap.find_opt
- (fun_id_to_fun_loop_id fun_id)
- !used_map
+ FunLoopIdMap.find_opt (fun_id, loop_id) !used_map
with
| None -> super#visit_texpression env e
| Some used_info ->
@@ -2267,13 +2247,13 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
in
let transl =
List.map
- (fun (b, (fwd, backs)) ->
- let filter_fun_and_loops (f, fl) =
- (filter_in_one f, List.map filter_in_one fl)
+ (fun trans ->
+ let filter_fun_and_loops f =
+ { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }
in
- let fwd = filter_fun_and_loops fwd in
- let backs = List.map filter_fun_and_loops backs in
- (b, (fwd, backs)))
+ let fwd = filter_fun_and_loops trans.fwd in
+ let backs = List.map filter_fun_and_loops trans.backs in
+ { trans with fwd; backs })
transl
in
@@ -2294,18 +2274,17 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
but convenient.
*)
let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
- (transl : (fun_decl * fun_decl list) list) :
- (bool * pure_fun_translation) list =
- let apply_to_one (trans : fun_decl * fun_decl list) :
- bool * pure_fun_translation =
+ (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list =
+ let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation =
(* Apply the passes to the individual functions *)
- let forward, backwards = trans in
- let forward = Option.get (apply_passes_to_def ctx forward) in
- let backwards = List.filter_map (apply_passes_to_def ctx) backwards in
- let trans = (forward, backwards) in
+ let fwd, backs = trans in
+ let fwd = Option.get (apply_passes_to_def ctx fwd) in
+ let backs = List.filter_map (apply_passes_to_def ctx) backs in
(* Compute whether we need to filter the forward function or not *)
- (keep_forward trans, trans)
+ let keep_fwd = keep_forward fwd backs in
+ { keep_fwd; fwd; backs }
in
+
let transl = List.map apply_to_one transl in
(* Filter the useless inputs in the loop functions *)
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 8d28bb8a..2ad942bb 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -9,17 +9,19 @@ open PureUtils
of fields is fixed: it shouldn't be used for arrays, slices, etc.
*)
let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
- (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list)
- (cgs : const_generic list) : ty list =
+ (type_id : type_id) (variant_id : VariantId.id option)
+ (generics : generic_args) : ty list =
match type_id with
| Tuple ->
(* Tuple *)
+ assert (generics.const_generics = []);
+ assert (generics.trait_refs = []);
assert (variant_id = None);
- tys
+ generics.types
| AdtId def_id ->
(* "Regular" ADT *)
let def = TypeDeclId.Map.find def_id type_decls in
- type_decl_get_instantiated_fields_types def variant_id tys cgs
+ type_decl_get_instantiated_fields_types def variant_id generics
| Assumed aty -> (
(* Assumed type *)
match aty with
@@ -27,14 +29,14 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
(* This type is opaque *)
raise (Failure "Unreachable: opaque type")
| Result ->
- let ty = Collections.List.to_cons_nil tys in
+ let ty = Collections.List.to_cons_nil generics.types in
let variant_id = Option.get variant_id in
if variant_id = result_return_id then [ ty ]
else if variant_id = result_fail_id then [ mk_error_ty ]
else
raise (Failure "Unreachable: improper variant id for result type")
| Error ->
- assert (tys = []);
+ assert (generics = empty_generic_args);
let variant_id = Option.get variant_id in
assert (
variant_id = error_failure_id || variant_id = error_out_of_fuel_id);
@@ -44,18 +46,7 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
if variant_id = fuel_zero_id then []
else if variant_id = fuel_succ_id then [ mk_fuel_ty ]
else raise (Failure "Unreachable: improper variant id for fuel type")
- | Option ->
- let ty = Collections.List.to_cons_nil tys in
- let variant_id = Option.get variant_id in
- if variant_id = option_some_id then [ ty ]
- else if variant_id = option_none_id then []
- else
- raise (Failure "Unreachable: improper variant id for option type")
- | Range ->
- let ty = Collections.List.to_cons_nil tys in
- assert (variant_id = None);
- [ ty; ty ]
- | Vec | Array | Slice | Str ->
+ | Array | Slice | Str | RawPtr _ ->
(* Array: when not symbolic values (for instance, because of aggregates),
the array expressions are introduced as struct updates *)
raise (Failure "Attempting to access the fields of an opaque type"))
@@ -65,6 +56,9 @@ type tc_ctx = {
global_decls : A.global_decl A.GlobalDeclId.Map.t;
(** The global declarations *)
env : ty VarId.Map.t; (** Environment from variables to types *)
+ const_generics : ty T.ConstGenericVarId.Map.t;
+ (** The types of the const generics *)
+ (* TODO: add trait type constraints *)
}
let check_literal (v : literal) (ty : literal_type) : unit =
@@ -86,12 +80,13 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx =
{ ctx with env }
| PatAdt av ->
(* Compute the field types *)
- let type_id, tys, cgs = ty_as_adt v.ty in
+ let type_id, generics = ty_as_adt v.ty in
let field_tys =
- get_adt_field_types ctx.type_decls type_id av.variant_id tys cgs
+ get_adt_field_types ctx.type_decls type_id av.variant_id generics
in
let check_value (ctx : tc_ctx) (ty : ty) (v : typed_pattern) : tc_ctx =
if ty <> v.ty then (
+ (* TODO: we need to normalize the types *)
log#serror
("check_typed_pattern: not the same types:" ^ "\n- ty: "
^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty);
@@ -115,6 +110,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
match VarId.Map.find_opt var_id ctx.env with
| None -> ()
| Some ty -> assert (ty = e.ty))
+ | CVar cg_id ->
+ let ty = T.ConstGenericVarId.Map.find cg_id ctx.const_generics in
+ assert (ty = e.ty)
| Const cv -> check_literal cv (ty_as_literal e.ty)
| App (app, arg) ->
let input_ty, output_ty = destruct_arrow app.ty in
@@ -133,35 +131,34 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
match qualif.id with
| FunOrOp _ -> () (* TODO *)
| Global _ -> () (* TODO *)
+ | TraitConst _ -> () (* TODO *)
| Proj { adt_id = proj_adt_id; field_id } ->
(* Note we can only project fields of structures (not enumerations) *)
(* Deconstruct the projector type *)
let adt_ty, field_ty = destruct_arrow e.ty in
- let adt_id, adt_type_args, adt_cg_args = ty_as_adt adt_ty in
+ let adt_id, adt_generics = ty_as_adt adt_ty in
(* Check the ADT type *)
assert (adt_id = proj_adt_id);
- assert (adt_type_args = qualif.type_args);
- assert (adt_cg_args = qualif.const_generic_args);
+ assert (adt_generics = qualif.generics);
(* Retrieve and check the expected field type *)
let variant_id = None in
let expected_field_tys =
get_adt_field_types ctx.type_decls proj_adt_id variant_id
- qualif.type_args qualif.const_generic_args
+ qualif.generics
in
let expected_field_ty = FieldId.nth expected_field_tys field_id in
assert (expected_field_ty = field_ty)
| AdtCons id -> (
let expected_field_tys =
get_adt_field_types ctx.type_decls id.adt_id id.variant_id
- qualif.type_args qualif.const_generic_args
+ qualif.generics
in
let field_tys, adt_ty = destruct_arrows e.ty in
assert (expected_field_tys = field_tys);
match adt_ty with
- | Adt (type_id, tys, cgs) ->
+ | Adt (type_id, generics) ->
assert (type_id = id.adt_id);
- assert (tys = qualif.type_args);
- assert (cgs = qualif.const_generic_args)
+ assert (generics = qualif.generics)
| _ -> raise (Failure "Unreachable")))
| Let (monadic, pat, re, e_next) ->
let expected_pat_ty = if monadic then destruct_result re.ty else re.ty in
@@ -207,15 +204,14 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
| Some ty -> assert (ty = e.ty));
(* Check the fields *)
(* Retrieve and check the expected field type *)
- let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in
+ let adt_id, adt_generics = ty_as_adt e.ty in
assert (adt_id = supd.struct_id);
(* The id can only be: a custom type decl or an array *)
match adt_id with
| AdtId _ ->
let variant_id = None in
let expected_field_tys =
- get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args
- adt_cg_args
+ get_adt_field_types ctx.type_decls adt_id variant_id adt_generics
in
List.iter
(fun (fid, fe) ->
@@ -224,7 +220,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
check_texpression ctx fe)
supd.updates
| Assumed Array ->
- let expected_field_ty = Collections.List.to_cons_nil adt_type_args in
+ let expected_field_ty =
+ Collections.List.to_cons_nil adt_generics.types
+ in
List.iter
(fun (_, fe) ->
assert (expected_field_ty = fe.ty);
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 1c8d8921..3aeabffe 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -89,14 +89,31 @@ let mk_mplace (var_id : E.VarId.id) (name : string option)
(projection : mprojection) : mplace =
{ var_id; name; projection }
+let empty_generic_params : generic_params =
+ { types = []; const_generics = []; trait_clauses = [] }
+
+let empty_generic_args : generic_args =
+ { types = []; const_generics = []; trait_refs = [] }
+
+let mk_generic_args_from_types (types : ty list) : generic_args =
+ { types; const_generics = []; trait_refs = [] }
+
+type subst = {
+ ty_subst : TypeVarId.id -> ty;
+ cg_subst : ConstGenericVarId.id -> const_generic;
+ tr_subst : TraitClauseId.id -> trait_instance_id;
+ tr_self : trait_instance_id;
+}
+
(** Type substitution *)
-let ty_substitute (tsubst : TypeVarId.id -> ty)
- (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty =
+let ty_substitute (subst : subst) (ty : ty) : ty =
let obj =
object
inherit [_] map_ty
- method! visit_TypeVar _ var_id = tsubst var_id
- method! visit_ConstGenericVar _ var_id = cgsubst var_id
+ method! visit_TypeVar _ var_id = subst.ty_subst var_id
+ method! visit_ConstGenericVar _ var_id = subst.cg_subst var_id
+ method! visit_Clause _ id = subst.tr_subst id
+ method! visit_Self _ = subst.tr_self
end
in
obj#visit_ty () ty
@@ -115,6 +132,18 @@ let make_const_generic_subst (vars : const_generic_var list)
(cgs : const_generic list) : ConstGenericVarId.id -> const_generic =
Substitute.make_const_generic_subst_from_vars vars cgs
+let make_trait_subst (clauses : trait_clause list) (refs : trait_ref list) :
+ TraitClauseId.id -> trait_instance_id =
+ let clauses = List.map (fun x -> x.clause_id) clauses in
+ let refs = List.map (fun x -> TraitRef x) refs in
+ let ls = List.combine clauses refs in
+ let mp =
+ List.fold_left
+ (fun mp (k, v) -> TraitClauseId.Map.add k v mp)
+ TraitClauseId.Map.empty ls
+ in
+ fun id -> TraitClauseId.Map.find id mp
+
(** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}.
Raises [Invalid_argument] if the arguments are incorrect.
@@ -135,20 +164,27 @@ let type_decl_get_fields (def : type_decl)
- def: " ^ show_type_decl def ^ "\n- opt_variant_id: "
^ opt_variant_id))
+let make_subst_from_generics (params : generic_params) (args : generic_args)
+ (tr_self : trait_instance_id) : subst =
+ let ty_subst = make_type_subst params.types args.types in
+ let cg_subst =
+ make_const_generic_subst params.const_generics args.const_generics
+ in
+ let tr_subst = make_trait_subst params.trait_clauses args.trait_refs in
+ { ty_subst; cg_subst; tr_subst; tr_self }
+
(** Instantiate the type variables for the chosen variant in an ADT definition,
and return the list of the types of its fields *)
let type_decl_get_instantiated_fields_types (def : type_decl)
- (opt_variant_id : VariantId.id option) (types : ty list)
- (cgs : const_generic list) : ty list =
- let ty_subst = make_type_subst def.type_params types in
- let cg_subst = make_const_generic_subst def.const_generic_params cgs in
+ (opt_variant_id : VariantId.id option) (generics : generic_args) : ty list =
+ (* There shouldn't be any reference to Self *)
+ let tr_self = UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics def.generics generics tr_self in
let fields = type_decl_get_fields def opt_variant_id in
- List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields
+ List.map (fun f -> ty_substitute subst f.field_ty) fields
-let fun_sig_substitute (tsubst : TypeVarId.id -> ty)
- (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) :
- inst_fun_sig =
- let subst = ty_substitute tsubst cgsubst in
+let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig =
+ let subst = ty_substitute subst in
let inputs = List.map subst sg.inputs in
let output = subst sg.output in
let doutputs = List.map subst sg.doutputs in
@@ -164,7 +200,8 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty)
*)
let rec let_group_requires_parentheses (e : texpression) : bool =
match e.e with
- | Var _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> false
+ | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ ->
+ false
| Let (monadic, _, _, next_e) ->
if monadic then true else let_group_requires_parentheses next_e
| Switch (_, _) -> false
@@ -184,15 +221,18 @@ let is_var (e : texpression) : bool =
let as_var (e : texpression) : VarId.id =
match e.e with Var v -> v | _ -> raise (Failure "Unreachable")
+let is_cvar (e : texpression) : bool =
+ match e.e with CVar _ -> true | _ -> false
+
let is_global (e : texpression) : bool =
match e.e with Qualif { id = Global _; _ } -> true | _ -> false
let is_const (e : texpression) : bool =
match e.e with Const _ -> true | _ -> false
-let ty_as_adt (ty : ty) : type_id * ty list * const_generic list =
+let ty_as_adt (ty : ty) : type_id * generic_args =
match ty with
- | Adt (id, tys, cgs) -> (id, tys, cgs)
+ | Adt (id, generics) -> (id, generics)
| _ -> raise (Failure "Unreachable")
(** Remove the external occurrences of {!Meta} *)
@@ -290,28 +330,30 @@ let destruct_qualif_app (e : texpression) : qualif * texpression list =
(** Destruct an expression into a function call, if possible *)
let opt_destruct_function_call (e : texpression) :
- (fun_or_op_id * ty list * texpression list) option =
+ (fun_or_op_id * generic_args * texpression list) option =
match opt_destruct_qualif_app e with
| None -> None
| Some (qualif, args) -> (
match qualif.id with
- | FunOrOp fun_id -> Some (fun_id, qualif.type_args, args)
+ | FunOrOp fun_id -> Some (fun_id, qualif.generics, args)
| _ -> None)
let opt_destruct_result (ty : ty) : ty option =
match ty with
- | Adt (Assumed Result, tys, cgs) ->
- assert (cgs = []);
- Some (Collections.List.to_cons_nil tys)
+ | Adt (Assumed Result, generics) ->
+ assert (generics.const_generics = []);
+ assert (generics.trait_refs = []);
+ Some (Collections.List.to_cons_nil generics.types)
| _ -> None
let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty)
let opt_destruct_tuple (ty : ty) : ty list option =
match ty with
- | Adt (Tuple, tys, cgs) ->
- assert (cgs = []);
- Some tys
+ | Adt (Tuple, generics) ->
+ assert (generics.const_generics = []);
+ assert (generics.trait_refs = []);
+ Some generics.types
| _ -> None
let mk_abs (x : typed_pattern) (e : texpression) : texpression =
@@ -383,14 +425,16 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
- if there is > one type: wrap them in a tuple
*)
let mk_simpl_tuple_ty (tys : ty list) : ty =
- match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, [])
+ match tys with
+ | [ ty ] -> ty
+ | _ -> Adt (Tuple, mk_generic_args_from_types tys)
let mk_bool_ty : ty = Literal Bool
-let mk_unit_ty : ty = Adt (Tuple, [], [])
+let mk_unit_ty : ty = Adt (Tuple, empty_generic_args)
let mk_unit_rvalue : texpression =
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = []; const_generic_args = [] } in
+ let qualif = { id; generics = empty_generic_args } in
let e = Qualif qualif in
let ty = mk_unit_ty in
{ e; ty }
@@ -430,7 +474,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern =
| [ v ] -> v
| _ ->
let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in
- let ty = Adt (Tuple, tys, []) in
+ let ty = Adt (Tuple, mk_generic_args_from_types tys) in
let value = PatAdt { variant_id = None; field_values = vl } in
{ value; ty }
@@ -441,11 +485,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression =
| _ ->
(* Compute the types of the fields, and the type of the tuple constructor *)
let tys = List.map (fun (v : texpression) -> v.ty) vl in
- let ty = Adt (Tuple, tys, []) in
+ let ty = Adt (Tuple, mk_generic_args_from_types tys) in
let ty = mk_arrows tys ty in
(* Construct the tuple constructor qualifier *)
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = tys; const_generic_args = [] } in
+ let qualif = { id; generics = mk_generic_args_from_types tys } in
(* Put everything together *)
let cons = { e = Qualif qualif; ty } in
mk_apps cons vl
@@ -463,32 +507,36 @@ let ty_as_integer (t : ty) : T.integer_type =
let ty_as_literal (t : ty) : T.literal_type =
match t with Literal ty -> ty | _ -> raise (Failure "Unreachable")
-let mk_state_ty : ty = Adt (Assumed State, [], [])
-let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], [])
-let mk_error_ty : ty = Adt (Assumed Error, [], [])
-let mk_fuel_ty : ty = Adt (Assumed Fuel, [], [])
+let mk_state_ty : ty = Adt (Assumed State, empty_generic_args)
+
+let mk_result_ty (ty : ty) : ty =
+ Adt (Assumed Result, mk_generic_args_from_types [ ty ])
+
+let mk_error_ty : ty = Adt (Assumed Error, empty_generic_args)
+let mk_fuel_ty : ty = Adt (Assumed Fuel, empty_generic_args)
let mk_error (error : VariantId.id) : texpression =
let ty = mk_error_ty in
let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in
- let qualif = { id; type_args = []; const_generic_args = [] } in
+ let qualif = { id; generics = empty_generic_args } in
let e = Qualif qualif in
{ e; ty }
let unwrap_result_ty (ty : ty) : ty =
match ty with
- | Adt (Assumed Result, [ ty ], cgs) ->
- assert (cgs = []);
+ | Adt
+ (Assumed Result, { types = [ ty ]; const_generics = []; trait_refs = [] })
+ ->
ty
| _ -> raise (Failure "not a result type")
let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression =
let type_args = [ ty ] in
- let ty = Adt (Assumed Result, type_args, []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id }
in
- let qualif = { id; type_args; const_generic_args = [] } in
+ let qualif = { id; generics = mk_generic_args_from_types type_args } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow error.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -501,11 +549,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) :
let mk_result_return_texpression (v : texpression) : texpression =
let type_args = [ v.ty ] in
- let ty = Adt (Assumed Result, type_args, []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id }
in
- let qualif = { id; type_args; const_generic_args = [] } in
+ let qualif = { id; generics = mk_generic_args_from_types type_args } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow v.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -514,7 +562,7 @@ let mk_result_return_texpression (v : texpression) : texpression =
(** Create a [Fail err] pattern which captures the error *)
let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern =
let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in
- let ty = Adt (Assumed Result, [ ty ], []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types [ ty ]) in
let value =
PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] }
in
@@ -526,7 +574,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern =
mk_result_fail_pattern error_pat ty
let mk_result_return_pattern (v : typed_pattern) : typed_pattern =
- let ty = Adt (Assumed Result, [ v.ty ], []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types [ v.ty ]) in
let value =
PatAdt { variant_id = Some result_return_id; field_values = [ v ] }
in
@@ -561,11 +609,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option
let fields_values = List.map (fun e -> Option.get e) fields in
(* Retrieve the type id and the type args from the pat type (simpler this way *)
- let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in
+ let adt_id, generics = ty_as_adt pat.ty in
(* Create the constructor *)
let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in
- let qualif = { id = qualif_id; type_args; const_generic_args } in
+ let qualif = { id = qualif_id; generics } in
let cons_e = Qualif qualif in
let field_tys =
List.map (fun (v : texpression) -> v.ty) fields_values
@@ -577,3 +625,55 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option
Some (mk_apps cons fields_values).e
in
match e_opt with None -> None | Some e -> Some { e; ty = pat.ty }
+
+type trait_decl_method_decl_id = { is_provided : bool; id : fun_decl_id }
+
+let trait_decl_get_method (trait_decl : trait_decl) (method_name : string) :
+ trait_decl_method_decl_id =
+ (* First look in the required methods *)
+ let method_id =
+ List.find_opt (fun (s, _) -> s = method_name) trait_decl.required_methods
+ in
+ match method_id with
+ | Some (_, id) -> { is_provided = false; id }
+ | None ->
+ (* Must be a provided method *)
+ let _, id =
+ List.find (fun (s, _) -> s = method_name) trait_decl.provided_methods
+ in
+ { is_provided = true; id = Option.get id }
+
+let trait_decl_is_empty (trait_decl : trait_decl) : bool =
+ let {
+ def_id = _;
+ name = _;
+ generics = _;
+ preds = _;
+ parent_clauses;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } =
+ trait_decl
+ in
+ parent_clauses = [] && consts = [] && types = [] && required_methods = []
+ && provided_methods = []
+
+let trait_impl_is_empty (trait_impl : trait_impl) : bool =
+ let {
+ def_id = _;
+ name = _;
+ impl_trait = _;
+ generics = _;
+ preds = _;
+ parent_trait_refs;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } =
+ trait_impl
+ in
+ parent_trait_refs = [] && consts = [] && types = [] && required_methods = []
+ && provided_methods = []
diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml
index fc4744bc..10b68da3 100644
--- a/compiler/ReorderDecls.ml
+++ b/compiler/ReorderDecls.ml
@@ -38,14 +38,16 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t =
method! visit_qualif _ id =
match id.id with
- | FunOrOp (Unop _ | Binop _) | Global _ | AdtCons _ | Proj _ -> ()
+ | FunOrOp (Unop _ | Binop _)
+ | Global _ | AdtCons _ | Proj _ | TraitConst _ ->
+ ()
| FunOrOp (Fun fid) -> (
match fid with
| Pure _ -> ()
| FromLlbc (fid, lp_id, rg_id) -> (
match fid with
- | Assumed _ -> ()
- | Regular fid ->
+ | FunId (Assumed _) -> ()
+ | TraitMethod (_, _, fid) | FunId (Regular fid) ->
let id = { def_id = fid; lp_id; rg_id } in
ids := FunIdSet.add id !ids))
end
diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml
index 38850243..23f618e2 100644
--- a/compiler/Substitute.ml
+++ b/compiler/Substitute.ml
@@ -9,51 +9,70 @@ module E = Expressions
module A = LlbcAst
module C = Contexts
-(** Substitute types variables and regions in a type. *)
-let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : 'r1 T.ty) :
- 'r2 T.ty =
- let open T in
- let visitor =
- object
- inherit [_] map_ty
- method visit_'r _ r = rsubst r
- method! visit_TypeVar _ id = tsubst id
+type ('r1, 'r2) subst = {
+ r_subst : 'r1 -> 'r2;
+ ty_subst : T.TypeVarId.id -> 'r2 T.ty;
+ cg_subst : T.ConstGenericVarId.id -> T.const_generic;
+ (** Substitution from *local* trait clause to trait instance *)
+ tr_subst : T.TraitClauseId.id -> 'r2 T.trait_instance_id;
+ (** Substitution for the [Self] trait instance *)
+ tr_self : 'r2 T.trait_instance_id;
+}
+
+let ty_substitute_visitor (subst : ('r1, 'r2) subst) =
+ object
+ inherit [_] T.map_ty
+ method visit_'r _ r = subst.r_subst r
+ method! visit_TypeVar _ id = subst.ty_subst id
- method! visit_type_var_id _ _ =
- (* We should never get here because we reimplemented [visit_TypeVar] *)
- raise (Failure "Unexpected")
+ method! visit_type_var_id _ _ =
+ (* We should never get here because we reimplemented [visit_TypeVar] *)
+ raise (Failure "Unexpected")
- method! visit_ConstGenericVar _ id = cgsubst id
+ method! visit_ConstGenericVar _ id = subst.cg_subst id
- method! visit_const_generic_var_id _ _ =
- (* We should never get here because we reimplemented [visit_Var] *)
- raise (Failure "Unexpected")
- end
- in
+ method! visit_const_generic_var_id _ _ =
+ (* We should never get here because we reimplemented [visit_Var] *)
+ raise (Failure "Unexpected")
- visitor#visit_ty () ty
+ method! visit_Clause _ id = subst.tr_subst id
+ method! visit_Self _ = subst.tr_self
+ end
-let rty_substitute (rsubst : T.RegionId.id -> T.RegionId.id)
- (tsubst : T.TypeVarId.id -> T.rty)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.rty) : T.rty =
- let rsubst r =
- match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid)
- in
- ty_substitute rsubst tsubst cgsubst ty
+(** Substitute types variables and regions in a type.
-let ety_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.ety) : T.ety =
- let rsubst r = r in
- ty_substitute rsubst tsubst cgsubst ty
+ **IMPORTANT**: this doesn't normalize the types.
+ *)
+let ty_substitute (subst : ('r1, 'r2) subst) (ty : 'r1 T.ty) : 'r2 T.ty =
+ let visitor = ty_substitute_visitor subst in
+ visitor#visit_ty () ty
+
+(** **IMPORTANT**: this doesn't normalize the types. *)
+let trait_ref_substitute (subst : ('r1, 'r2) subst) (tr : 'r1 T.trait_ref) :
+ 'r2 T.trait_ref =
+ let visitor = ty_substitute_visitor subst in
+ visitor#visit_trait_ref () tr
+
+(** **IMPORTANT**: this doesn't normalize the types. *)
+let generic_args_substitute (subst : ('r1, 'r2) subst) (g : 'r1 T.generic_args)
+ : 'r2 T.generic_args =
+ let visitor = ty_substitute_visitor subst in
+ visitor#visit_generic_args () g
+
+let erase_regions_subst : ('r, T.erased_region) subst =
+ {
+ r_subst = (fun _ -> T.Erased);
+ ty_subst = (fun vid -> T.TypeVar vid);
+ cg_subst = (fun id -> T.ConstGenericVar id);
+ tr_subst = (fun id -> T.Clause id);
+ tr_self = T.Self;
+ }
(** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *)
-let erase_regions (ty : T.rty) : T.ety =
- ty_substitute
- (fun _ -> T.Erased)
- (fun vid -> T.TypeVar vid)
- (fun id -> T.ConstGenericVar id)
- ty
+let erase_regions (ty : 'r T.ty) : T.ety = ty_substitute erase_regions_subst ty
+
+let trait_ref_erase_regions (tr : 'r T.trait_ref) : T.etrait_ref =
+ trait_ref_substitute erase_regions_subst tr
(** Generate fresh regions for region variables.
@@ -78,18 +97,20 @@ let fresh_regions_with_substs (region_vars : T.region_var list) :
(* Generate the substitution from region var id to region *)
let rid_subst id = T.RegionVarId.Map.find id rid_map in
(* Generate the substitution from region to region *)
- let rsubst r =
+ let r_subst r =
match r with T.Static -> T.Static | T.Var id -> T.Var (rid_subst id)
in
(* Return *)
- (fresh_region_ids, rid_subst, rsubst)
+ (fresh_region_ids, rid_subst, r_subst)
-(** Erase the regions in a type and substitute the type variables *)
-let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic)
- (ty : 'r T.region T.ty) : T.ety =
- let rsubst (_ : 'r T.region) : T.erased_region = T.Erased in
- ty_substitute rsubst tsubst cgsubst ty
+(** Erase the regions in a type and perform a substitution *)
+let erase_regions_substitute_types (ty_subst : T.TypeVarId.id -> T.ety)
+ (cg_subst : T.ConstGenericVarId.id -> T.const_generic)
+ (tr_subst : T.TraitClauseId.id -> T.etrait_instance_id)
+ (tr_self : T.etrait_instance_id) (ty : 'r T.ty) : T.ety =
+ let r_subst (_ : 'r) : T.erased_region = T.Erased in
+ let subst = { r_subst; ty_subst; cg_subst; tr_subst; tr_self } in
+ ty_substitute subst ty
(** Create a region substitution from a list of region variable ids and a list of
regions (with which to substitute the region variable ids *)
@@ -146,16 +167,81 @@ let make_const_generic_subst_from_vars (vars : T.const_generic_var list)
(List.map (fun (x : T.const_generic_var) -> x.T.index) vars)
cgs
-(** Instantiate the type variables in an ADT definition, and return, for
- every variant, the list of the types of its fields *)
-let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) : (T.VariantId.id option * T.rty list) list =
- let r_subst = make_region_subst_from_vars def.T.region_params regions in
- let ty_subst = make_type_subst_from_vars def.T.type_params types in
+(** Create a trait substitution from a list of trait clause ids and a list of
+ trait refs *)
+let make_trait_subst (clause_ids : T.TraitClauseId.id list)
+ (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id =
+ let ls = List.combine clause_ids trs in
+ let mp =
+ List.fold_left
+ (fun mp (k, v) -> T.TraitClauseId.Map.add k (T.TraitRef v) mp)
+ T.TraitClauseId.Map.empty ls
+ in
+ fun id -> T.TraitClauseId.Map.find id mp
+
+let make_trait_subst_from_clauses (clauses : T.trait_clause list)
+ (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id =
+ make_trait_subst
+ (List.map (fun (x : T.trait_clause) -> x.T.clause_id) clauses)
+ trs
+
+let make_subst_from_generics (params : T.generic_params)
+ (args : 'r T.region T.generic_args)
+ (tr_self : 'r T.region T.trait_instance_id) :
+ (T.region_var_id T.region, 'r T.region) subst =
+ let r_subst = make_region_subst_from_vars params.T.regions args.T.regions in
+ let ty_subst = make_type_subst_from_vars params.T.types args.T.types in
let cg_subst =
- make_const_generic_subst_from_vars def.T.const_generic_params cgs
+ make_const_generic_subst_from_vars params.T.const_generics
+ args.T.const_generics
+ in
+ let tr_subst =
+ make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs
+ in
+ { r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+
+let make_subst_from_generics_no_regions :
+ 'r.
+ T.generic_params ->
+ 'r T.generic_args ->
+ 'r T.trait_instance_id ->
+ (T.region_var_id T.region, 'r) subst =
+ fun params args tr_self ->
+ let r_subst _ = raise (Failure "Unexpected region") in
+ let ty_subst = make_type_subst_from_vars params.T.types args.T.types in
+ let cg_subst =
+ make_const_generic_subst_from_vars params.T.const_generics
+ args.T.const_generics
+ in
+ let tr_subst =
+ make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs
+ in
+ { r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+
+let make_esubst_from_generics (params : T.generic_params)
+ (generics : T.egeneric_args) (tr_self : T.etrait_instance_id) =
+ let r_subst _ = T.Erased in
+ let ty_subst = make_type_subst_from_vars params.types generics.T.types in
+ let cg_subst =
+ make_const_generic_subst_from_vars params.const_generics
+ generics.T.const_generics
+ in
+ let tr_subst =
+ make_trait_subst_from_clauses params.trait_clauses generics.T.trait_refs
in
+ { r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+
+(** Instantiate the type variables in an ADT definition, and return, for
+ every variant, the list of the types of its fields.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
+let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl)
+ (generics : T.rgeneric_args) : (T.VariantId.id option * T.rty list) list =
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics def.T.generics generics tr_self in
let (variants_fields : (T.VariantId.id option * T.field list) list) =
match def.T.kind with
| T.Enum variants ->
@@ -171,191 +257,220 @@ let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl)
in
List.map
(fun (id, fields) ->
- ( id,
- List.map
- (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty)
- fields ))
+ (id, List.map (fun f -> ty_substitute subst f.T.field_ty) fields))
variants_fields
(** Instantiate the type variables in an ADT definition, and return the list
- of types of the fields for the chosen variant *)
+ of types of the fields for the chosen variant.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
let type_decl_get_instantiated_field_rtypes (def : T.type_decl)
- (opt_variant_id : T.VariantId.id option)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) : T.rty list =
- let r_subst = make_region_subst_from_vars def.T.region_params regions in
- let ty_subst = make_type_subst_from_vars def.T.type_params types in
- let cg_subst =
- make_const_generic_subst_from_vars def.T.const_generic_params cgs
- in
+ (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) :
+ T.rty list =
+ (* For now, check that there are no clauses - otherwise we might need
+ to normalize the types *)
+ assert (def.generics.trait_clauses = []);
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics def.T.generics generics tr_self in
let fields = TU.type_decl_get_fields def opt_variant_id in
- List.map
- (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty)
- fields
+ List.map (fun f -> ty_substitute subst f.T.field_ty) fields
(** Return the types of the properly instantiated ADT's variant, provided a
- context *)
+ context.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
let ctx_adt_get_instantiated_field_rtypes (ctx : C.eval_ctx)
(def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) : T.rty list =
+ (generics : T.rgeneric_args) : T.rty list =
let def = C.ctx_lookup_type_decl ctx def_id in
- type_decl_get_instantiated_field_rtypes def opt_variant_id regions types cgs
+ type_decl_get_instantiated_field_rtypes def opt_variant_id generics
(** Return the types of the properly instantiated ADT value (note that
- here, ADT is understood in its broad meaning: ADT, assumed value or tuple) *)
+ here, ADT is understood in its broad meaning: ADT, assumed value or tuple).
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+ *)
let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx)
- (adt : V.adt_value) (id : T.type_id)
- (region_params : T.RegionId.id T.region list) (type_params : T.rty list)
- (cg_params : T.const_generic list) : T.rty list =
+ (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) :
+ T.rty list =
match id with
| T.AdtId id ->
(* Retrieve the types of the fields *)
- ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id
- region_params type_params cg_params
+ ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id generics
| T.Tuple ->
- assert (List.length region_params = 0);
- type_params
+ assert (generics.regions = []);
+ generics.types
| T.Assumed aty -> (
match aty with
- | T.Box | T.Vec ->
- assert (List.length region_params = 0);
- assert (List.length type_params = 1);
- assert (List.length cg_params = 0);
- type_params
- | T.Option ->
- assert (List.length region_params = 0);
- assert (List.length type_params = 1);
- assert (List.length cg_params = 0);
- if adt.V.variant_id = Some T.option_some_id then type_params
- else if adt.V.variant_id = Some T.option_none_id then []
- else raise (Failure "Unreachable")
- | T.Range ->
- assert (List.length region_params = 0);
- assert (List.length type_params = 1);
- assert (List.length cg_params = 0);
- type_params
+ | T.Box ->
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ generics.types
| T.Array | T.Slice | T.Str ->
(* Those types don't have fields *)
raise (Failure "Unreachable"))
(** Instantiate the type variables in an ADT definition, and return the list
- of types of the fields for the chosen variant *)
+ of types of the fields for the chosen variant.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
let type_decl_get_instantiated_field_etypes (def : T.type_decl)
- (opt_variant_id : T.VariantId.id option) (types : T.ety list)
- (cgs : T.const_generic list) : T.ety list =
- let ty_subst = make_type_subst_from_vars def.T.type_params types in
- let cg_subst =
- make_const_generic_subst_from_vars def.T.const_generic_params cgs
+ (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) :
+ T.ety list =
+ (* For now, check that there are no clauses - otherwise we might need
+ to normalize the types *)
+ assert (def.generics.trait_clauses = []);
+ (* There shouldn't be any reference to Self *)
+ let tr_self : T.erased_region T.trait_instance_id =
+ T.UnknownTrait __FUNCTION__
+ in
+ let { r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } =
+ make_esubst_from_generics def.T.generics generics tr_self
in
let fields = TU.type_decl_get_fields def opt_variant_id in
List.map
- (fun f -> erase_regions_substitute_types ty_subst cg_subst f.T.field_ty)
+ (fun (f : T.field) ->
+ erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self
+ f.T.field_ty)
fields
(** Return the types of the properly instantiated ADT's variant, provided a
- context *)
+ context.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+ *)
let ctx_adt_get_instantiated_field_etypes (ctx : C.eval_ctx)
(def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
- (types : T.ety list) (cgs : T.const_generic list) : T.ety list =
+ (generics : T.egeneric_args) : T.ety list =
let def = C.ctx_lookup_type_decl ctx def_id in
- type_decl_get_instantiated_field_etypes def opt_variant_id types cgs
+ type_decl_get_instantiated_field_etypes def opt_variant_id generics
-let statement_substitute_visitor (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) =
+let statement_substitute_visitor
+ (subst : (T.erased_region, T.erased_region) subst) =
+ (* Keep in synch with [ty_substitute_visitor] *)
object
inherit [_] A.map_statement
- method! visit_ety _ ty = ety_substitute tsubst cgsubst ty
- method! visit_ConstGenericVar _ id = cgsubst id
+ method! visit_'r _ r = subst.r_subst r
+ method! visit_TypeVar _ id = subst.ty_subst id
+
+ method! visit_type_var_id _ _ =
+ (* We should never get here because we reimplemented [visit_TypeVar] *)
+ raise (Failure "Unexpected")
+
+ method! visit_ConstGenericVar _ id = subst.cg_subst id
method! visit_const_generic_var_id _ _ =
(* We should never get here because we reimplemented [visit_Var] *)
raise (Failure "Unexpected")
+
+ method! visit_Clause _ id = subst.tr_subst id
+ method! visit_Self _ = subst.tr_self
end
(** Apply a type substitution to a place *)
-let place_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (p : E.place) :
- E.place =
+let place_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (p : E.place) : E.place =
(* There is in fact nothing to do *)
- (statement_substitute_visitor tsubst cgsubst)#visit_place () p
+ (statement_substitute_visitor subst)#visit_place () p
(** Apply a type substitution to an operand *)
-let operand_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (op : E.operand) :
- E.operand =
- (statement_substitute_visitor tsubst cgsubst)#visit_operand () op
+let operand_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (op : E.operand) : E.operand =
+ (statement_substitute_visitor subst)#visit_operand () op
(** Apply a type substitution to an rvalue *)
-let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (rv : E.rvalue) :
- E.rvalue =
- (statement_substitute_visitor tsubst cgsubst)#visit_rvalue () rv
+let rvalue_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (rv : E.rvalue) : E.rvalue =
+ (statement_substitute_visitor subst)#visit_rvalue () rv
(** Apply a type substitution to an assertion *)
-let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (a : A.assertion) :
- A.assertion =
- (statement_substitute_visitor tsubst cgsubst)#visit_assertion () a
+let assertion_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (a : A.assertion) : A.assertion =
+ (statement_substitute_visitor subst)#visit_assertion () a
(** Apply a type substitution to a call *)
-let call_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (call : A.call) :
- A.call =
- (statement_substitute_visitor tsubst cgsubst)#visit_call () call
+let call_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (call : A.call) : A.call =
+ (statement_substitute_visitor subst)#visit_call () call
(** Apply a type substitution to a statement *)
-let statement_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (st : A.statement) :
- A.statement =
- (statement_substitute_visitor tsubst cgsubst)#visit_statement () st
+let statement_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (st : A.statement) : A.statement =
+ (statement_substitute_visitor subst)#visit_statement () st
(** Apply a type substitution to a function body. Return the local variables
and the body. *)
-let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (body : A.fun_body) :
+let fun_body_substitute_in_body
+ (subst : (T.erased_region, T.erased_region) subst) (body : A.fun_body) :
A.var list * A.statement =
- let rsubst r = r in
let locals =
List.map
- (fun (v : A.var) ->
- { v with A.var_ty = ty_substitute rsubst tsubst cgsubst v.A.var_ty })
+ (fun (v : A.var) -> { v with A.var_ty = ty_substitute subst v.A.var_ty })
body.A.locals
in
- let body = statement_substitute tsubst cgsubst body.body in
+ let body = statement_substitute subst body.body in
(locals, body)
-(** Substitute a function signature *)
+let trait_type_constraint_substitute (subst : ('r1, 'r2) subst)
+ (ttc : 'r1 T.trait_type_constraint) : 'r2 T.trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let visitor = ty_substitute_visitor subst in
+ let trait_ref = visitor#visit_trait_ref () trait_ref in
+ let generics = visitor#visit_generic_args () generics in
+ let ty = visitor#visit_ty () ty in
+ { T.trait_ref; generics; type_name; ty }
+
+(** Substitute a function signature.
+
+ **IMPORTANT:** this function doesn't normalize the types.
+ *)
let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id)
- (rsubst : T.RegionVarId.id -> T.RegionId.id)
- (tsubst : T.TypeVarId.id -> T.rty)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (sg : A.fun_sig) :
- A.inst_fun_sig =
- let rsubst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region =
- match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid)
+ (r_subst : T.RegionVarId.id -> T.RegionId.id)
+ (ty_subst : T.TypeVarId.id -> T.rty)
+ (cg_subst : T.ConstGenericVarId.id -> T.const_generic)
+ (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id)
+ (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
+ let r_subst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region =
+ match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid)
in
- let inputs = List.map (ty_substitute rsubst' tsubst cgsubst) sg.A.inputs in
- let output = ty_substitute rsubst' tsubst cgsubst sg.A.output in
+ let subst = { r_subst = r_subst'; ty_subst; cg_subst; tr_subst; tr_self } in
+ let inputs = List.map (ty_substitute subst) sg.A.inputs in
+ let output = ty_substitute subst sg.A.output in
let subst_region_group (rg : T.region_var_group) : A.abs_region_group =
let id = asubst rg.id in
- let regions = List.map rsubst rg.regions in
+ let regions = List.map r_subst rg.regions in
let parents = List.map asubst rg.parents in
{ id; regions; parents }
in
let regions_hierarchy = List.map subst_region_group sg.A.regions_hierarchy in
- { A.regions_hierarchy; inputs; output }
+ let trait_type_constraints =
+ List.map
+ (trait_type_constraint_substitute subst)
+ sg.preds.trait_type_constraints
+ in
+ { A.inputs; output; regions_hierarchy; trait_type_constraints }
-(** Substitute type variable identifiers in a type *)
-let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty)
+(** Substitute variable identifiers in a type *)
+let ty_substitute_ids (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty)
: 'r T.ty =
let open T in
let visitor =
object
inherit [_] map_ty
method visit_'r _ r = r
- method! visit_type_var_id _ id = tsubst id
- method! visit_const_generic_var_id _ id = cgsubst id
+ method! visit_type_var_id _ id = ty_subst id
+ method! visit_const_generic_var_id _ id = cg_subst id
end
in
@@ -371,10 +486,10 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
[visit_'r] if we define a class which visits objects of types [ety] and [rty]
while inheriting a class which visit [ty]...
*)
-let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
+let subst_ids_visitor (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id)
(asubst : V.AbstractionId.id -> V.AbstractionId.id) =
@@ -383,10 +498,10 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
inherit [_] T.map_ty
method visit_'r _ r =
- match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid)
+ match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid)
- method! visit_type_var_id _ id = tsubst id
- method! visit_const_generic_var_id _ id = cgsubst id
+ method! visit_type_var_id _ id = ty_subst id
+ method! visit_const_generic_var_id _ id = cg_subst id
end
in
@@ -395,7 +510,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
inherit [_] C.map_env
method! visit_borrow_id _ bid = bsubst bid
method! visit_loan_id _ bid = bsubst bid
- method! visit_ety _ ty = ty_substitute_ids tsubst cgsubst ty
+ method! visit_ety _ ty = ty_substitute_ids ty_subst cg_subst ty
method! visit_rty env ty = subst_rty#visit_ty env ty
method! visit_symbolic_value_id _ id = ssubst id
@@ -405,7 +520,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
(** We *do* visit meta-values *)
method! visit_mvalue env v = self#visit_typed_value env v
- method! visit_region_id _ id = rsubst id
+ method! visit_region_id _ id = r_subst id
method! visit_region_var_id _ id = rvsubst id
method! visit_abstraction_id _ id = asubst id
end
@@ -425,20 +540,20 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
method visit_env (env : C.env) : C.env = visitor#visit_env () env
end
-let typed_value_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_value_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_value) :
V.typed_value =
let asubst _ = raise (Failure "Unreachable") in
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_typed_value v
-let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_value_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id)
(v : V.typed_value) : V.typed_value =
- typed_value_subst_ids rsubst
+ typed_value_subst_ids r_subst
(fun x -> x)
(fun x -> x)
(fun x -> x)
@@ -446,41 +561,41 @@ let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
(fun x -> x)
v
-let typed_avalue_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_avalue_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_avalue) :
V.typed_avalue =
let asubst _ = raise (Failure "Unreachable") in
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_typed_avalue v
-let abs_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let abs_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id)
(asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : V.abs) : V.abs =
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_abs x
-let env_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let env_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id)
(asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : C.env) : C.env =
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_env x
-let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_avalue_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id)
(x : V.typed_avalue) : V.typed_avalue =
let asubst _ = raise (Failure "Unreachable") in
- (subst_ids_visitor rsubst
+ (subst_ids_visitor r_subst
(fun x -> x)
(fun x -> x)
(fun x -> x)
@@ -490,9 +605,9 @@ let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
#visit_typed_avalue
x
-let env_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : C.env) : C.env
- =
- (subst_ids_visitor rsubst
+let env_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id) (x : C.env) :
+ C.env =
+ (subst_ids_visitor r_subst
(fun x -> x)
(fun x -> x)
(fun x -> x)
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 7dc94dcd..4df8fec7 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -29,7 +29,7 @@ type mplace = {
[@@deriving show]
type call_id =
- | Fun of A.fun_id * V.FunCallId.id
+ | Fun of A.fun_id_or_trait_method_ref * V.FunCallId.id
(** A "regular" function (i.e., a function which is not a primitive operation) *)
| Unop of E.unop
| Binop of E.binop
@@ -43,10 +43,7 @@ type call = {
borrows (we need to perform lookups).
*)
abstractions : V.AbstractionId.id list;
- (* TODO: rename to "...args" *)
- type_params : T.ety list;
- (* TODO: rename to "...args" *)
- const_generic_params : T.const_generic list;
+ generics : T.egeneric_args;
args : V.typed_value list;
args_places : mplace option list; (** Meta information *)
dest : V.symbolic_value;
@@ -79,6 +76,9 @@ class ['self] iter_expression_base =
method visit_loop_id : 'env -> V.loop_id -> unit = fun _ _ -> ()
method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> ()
+ method visit_const_generic_var_id : 'env -> T.const_generic_var_id -> unit =
+ fun _ _ -> ()
+
method visit_symbolic_value_id : 'env -> V.symbolic_value_id -> unit =
fun _ _ -> ()
@@ -120,6 +120,9 @@ class ['self] iter_expression_base =
method visit_symbolic_expansion : 'env -> V.symbolic_expansion -> unit =
fun _ _ -> ()
+
+ method visit_etrait_ref : 'env -> T.etrait_ref -> unit = fun _ _ -> ()
+ method visit_egeneric_args : 'env -> T.egeneric_args -> unit = fun _ _ -> ()
end
(** **Rem.:** here, {!expression} is not at all equivalent to the expressions
@@ -171,14 +174,15 @@ type expression =
* expression
(** We introduce a new symbolic value, equal to some other value.
- This is used for instance when reorganizing the environment to compute
- fixed points: we duplicate some shared symbolic values to destructure
- the shared values, in order to make the environment a bit more general
- (while losing precision of course).
+ This is used for instance when reorganizing the environment to compute
+ fixed points: we duplicate some shared symbolic values to destructure
+ the shared values, in order to make the environment a bit more general
+ (while losing precision of course). We also use it to introduce symbolic
+ values when evaluating constant generics, or trait constants.
- The context is the evaluation context from before introducing the new
- value. It has the same purpose as for the {!Return} case.
- *)
+ The context is the evaluation context from before introducing the new
+ value. It has the same purpose as for the {!Return} case.
+ *)
| ForwardEnd of
Contexts.eval_ctx
* V.typed_value symbolic_value_id_map option
@@ -253,6 +257,11 @@ and value_aggregate =
| SingleValue of V.typed_value (** Regular case *)
| Array of V.typed_value list
(** This is used when introducing array aggregates *)
+ | ConstGenericValue of T.const_generic_var_id
+ (** This is used when evaluating a const generic value: in the interpreter,
+ we introduce a fresh symbolic value. *)
+ | TraitConstValue of T.etrait_ref * T.egeneric_args * string
+ (** A trait constant value *)
[@@deriving
show,
visitors
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3512270a..2ce8c706 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -4,6 +4,7 @@ open Pure
open PureUtils
module Id = Identifiers
module C = Contexts
+module A = LlbcAst
module S = SymbolicAst
module TA = TypesAnalysis
module L = Logging
@@ -52,6 +53,9 @@ type fun_context = {
type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t }
[@@deriving show]
+type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show]
+type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show]
+
(** Whenever we translate a function call or an ended abstraction, we
store the related information (this is useful when translating ended
children abstractions).
@@ -106,8 +110,7 @@ type loop_info = {
loop_id : LoopId.id;
input_vars : var list;
input_svl : V.symbolic_value list;
- type_args : ty list;
- const_generic_args : const_generic list;
+ generics : generic_args;
forward_inputs : texpression list option;
(** The forward inputs are initialized at [None] *)
forward_output_no_state_no_result : var option;
@@ -120,6 +123,8 @@ type bs_ctx = {
type_context : type_context;
fun_context : fun_context;
global_context : global_context;
+ trait_decls_ctx : trait_decls_context;
+ trait_impls_ctx : trait_impls_context;
fun_decl : A.fun_decl;
bid : T.RegionGroupId.id option; (** TODO: rename *)
sg : fun_sig;
@@ -201,34 +206,11 @@ type bs_ctx = {
}
[@@deriving show]
-let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit =
- let env = VarId.Map.empty in
- let ctx =
- {
- PureTypeCheck.type_decls = ctx.type_context.type_decls;
- global_decls = ctx.global_context.llbc_global_decls;
- env;
- }
- in
- let _ = PureTypeCheck.check_typed_pattern ctx v in
- ()
-
-let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit =
- let env = VarId.Map.empty in
- let ctx =
- {
- PureTypeCheck.type_decls = ctx.type_context.type_decls;
- global_decls = ctx.global_context.llbc_global_decls;
- env;
- }
- in
- PureTypeCheck.check_texpression ctx e
-
(* TODO: move *)
let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.Ast.ast_formatter =
Print.Ast.decls_and_fun_decl_to_ast_formatter ctx.type_context.llbc_type_decls
ctx.fun_context.llbc_fun_decls ctx.global_context.llbc_global_decls
- ctx.fun_decl
+ ctx.trait_decls_ctx ctx.trait_impls_ctx ctx.fun_decl
let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
let rvar_to_string = Print.Types.region_var_id_to_string in
@@ -246,16 +228,25 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
adt_variant_to_string = ast_fmt.adt_variant_to_string;
var_id_to_string;
adt_field_names = ast_fmt.adt_field_names;
+ trait_decl_id_to_string = ast_fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = ast_fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = ast_fmt.trait_clause_id_to_string;
}
let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter =
- let type_params = ctx.fun_decl.signature.type_params in
- let cg_params = ctx.fun_decl.signature.const_generic_params in
+ let generics = ctx.fun_decl.signature.generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx generics.types
+ generics.const_generics
+
+let ctx_egeneric_args_to_string (ctx : bs_ctx) (args : T.egeneric_args) : string
+ =
+ let fmt = bs_ctx_to_ctx_formatter ctx in
+ let fmt = Print.PC.ctx_to_etype_formatter fmt in
+ Print.PT.egeneric_args_to_string fmt args
let symbolic_value_to_string (ctx : bs_ctx) (sv : V.symbolic_value) : string =
let fmt = bs_ctx_to_ctx_formatter ctx in
@@ -277,12 +268,11 @@ let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string =
Print.PT.rty_to_string fmt ty
let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string =
- let type_params = def.type_params in
- let cg_params = def.const_generic_params in
let type_decls = ctx.type_context.llbc_type_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_type_formatter type_decls global_decls type_params cg_params
+ PrintPure.mk_type_formatter type_decls global_decls ctx.trait_decls_ctx
+ ctx.trait_impls_ctx def.generics.types def.generics.const_generics
in
PrintPure.type_decl_to_string fmt def
@@ -291,26 +281,27 @@ let texpression_to_string (ctx : bs_ctx) (e : texpression) : string =
PrintPure.texpression_to_string fmt false "" " " e
let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string =
- let type_params = sg.type_params in
- let cg_params = sg.const_generic_params in
+ let type_params = sg.generics.types in
+ let cg_params = sg.generics.const_generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params
in
PrintPure.fun_sig_to_string fmt sg
let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string =
- let type_params = def.signature.type_params in
- let cg_params = def.signature.const_generic_params in
+ let generics = def.signature.generics in
+ let type_params = generics.types in
+ let cg_params = generics.const_generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params
in
PrintPure.fun_decl_to_string fmt def
@@ -328,17 +319,18 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
Print.Values.abs_to_string fmt verbose indent indent_incr abs
let get_instantiated_fun_sig (fun_id : A.fun_id)
- (back_id : T.RegionGroupId.id option) (tys : ty list)
- (cgs : const_generic list) (ctx : bs_ctx) : inst_fun_sig =
+ (back_id : T.RegionGroupId.id option) (generics : generic_args)
+ (ctx : bs_ctx) : inst_fun_sig =
(* Lookup the non-instantiated function signature *)
let sg =
(RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg
in
(* Create the substitution *)
- let tsubst = make_type_subst sg.type_params tys in
- let cgsubst = make_const_generic_subst sg.const_generic_params cgs in
+ (* There shouldn't be any reference to Self *)
+ let tr_self = UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics sg.generics generics tr_self in
(* Apply *)
- fun_sig_substitute tsubst cgsubst sg
+ fun_sig_substitute subst sg
let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) :
T.type_decl =
@@ -351,77 +343,128 @@ let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) :
(* TODO: move *)
let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id)
(back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig =
- let id = (A.Regular def_id, back_id) in
+ let id = (E.Regular def_id, back_id) in
(RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg
-let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
- (args : texpression list) (ctx : bs_ctx) : bs_ctx =
- let calls = ctx.calls in
- assert (not (V.FunCallId.Map.mem call_id calls));
- let info =
- { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
- in
- let calls = V.FunCallId.Map.add call_id info calls in
- { ctx with calls }
-
-(** [back_args]: the *additional* list of inputs received by the backward function *)
-let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
- (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
- : bs_ctx * fun_or_op_id =
- (* Insert the abstraction in the call informations *)
- let info = V.FunCallId.Map.find call_id ctx.calls in
- assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
- let backwards =
- T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
- in
- let info = { info with backwards } in
- let calls = V.FunCallId.Map.add call_id info ctx.calls in
- (* Insert the abstraction in the abstractions map *)
- let abstractions = ctx.abstractions in
- assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
- let abstractions =
- V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
- in
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) -> Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
- in
- (* Update the context and return *)
- ({ ctx with calls; abstractions }, fun_id)
+(* Some generic translation functions (we need to translate different "flavours"
+ of types: sty, forward types, backward types, etc.) *)
+let rec translate_generic_args (translate_ty : 'r T.ty -> ty)
+ (generics : 'r T.generic_args) : generic_args =
+ (* We ignore the regions: if they didn't cause trouble for the symbolic execution,
+ then everything's fine *)
+ let types = List.map translate_ty generics.types in
+ let const_generics = generics.const_generics in
+ let trait_refs =
+ List.map (translate_trait_ref translate_ty) generics.trait_refs
+ in
+ { types; const_generics; trait_refs }
+
+and translate_trait_ref (translate_ty : 'r T.ty -> ty) (tr : 'r T.trait_ref) :
+ trait_ref =
+ let trait_id = translate_trait_instance_id translate_ty tr.trait_id in
+ let generics = translate_generic_args translate_ty tr.generics in
+ let trait_decl_ref =
+ translate_trait_decl_ref translate_ty tr.trait_decl_ref
+ in
+ { trait_id; generics; trait_decl_ref }
+
+and translate_trait_decl_ref (translate_ty : 'r T.ty -> ty)
+ (tr : 'r T.trait_decl_ref) : trait_decl_ref =
+ let decl_generics = translate_generic_args translate_ty tr.decl_generics in
+ { trait_decl_id = tr.trait_decl_id; decl_generics }
+
+and translate_trait_instance_id (translate_ty : 'r T.ty -> ty)
+ (id : 'r T.trait_instance_id) : trait_instance_id =
+ let translate_trait_instance_id = translate_trait_instance_id translate_ty in
+ match id with
+ | T.Self -> Self
+ | TraitImpl id -> TraitImpl id
+ | BuiltinOrAuto _ ->
+ (* We should have eliminated those in the prepasses *)
+ raise (Failure "Unreachable")
+ | Clause id -> Clause id
+ | ParentClause (inst_id, decl_id, clause_id) ->
+ let inst_id = translate_trait_instance_id inst_id in
+ ParentClause (inst_id, decl_id, clause_id)
+ | ItemClause (inst_id, decl_id, item_name, clause_id) ->
+ let inst_id = translate_trait_instance_id inst_id in
+ ItemClause (inst_id, decl_id, item_name, clause_id)
+ | TraitRef tr -> TraitRef (translate_trait_ref translate_ty tr)
+ | FnPointer _ -> raise (Failure "TODO")
+ | UnknownTrait s -> raise (Failure ("Unknown trait found: " ^ s))
let rec translate_sty (ty : T.sty) : ty =
let translate = translate_sty in
match ty with
- | T.Adt (type_id, regions, tys, cgs) -> (
- (* Can't translate types with regions for now *)
- assert (regions = []);
- let tys = List.map translate tys in
+ | T.Adt (type_id, generics) -> (
+ let generics = translate_sgeneric_args generics in
match type_id with
- | T.AdtId adt_id -> Adt (AdtId adt_id, tys, cgs)
- | T.Tuple -> mk_simpl_tuple_ty tys
+ | T.AdtId adt_id -> Adt (AdtId adt_id, generics)
+ | T.Tuple ->
+ assert (generics.const_generics = []);
+ mk_simpl_tuple_ty generics.types
| T.Assumed aty -> (
match aty with
- | T.Vec -> Adt (Assumed Vec, tys, cgs)
- | T.Option -> Adt (Assumed Option, tys, cgs)
| T.Box -> (
(* Eliminate the boxes *)
- match tys with
+ match generics.types with
| [ ty ] -> ty
| _ ->
raise
(Failure
"Box/vec/option type with incorrect number of arguments")
)
- | T.Array -> Adt (Assumed Array, tys, cgs)
- | T.Slice -> Adt (Assumed Slice, tys, cgs)
- | T.Str -> Adt (Assumed Str, tys, cgs)
- | T.Range -> Adt (Assumed Range, tys, cgs)))
+ | T.Array -> Adt (Assumed Array, generics)
+ | T.Slice -> Adt (Assumed Slice, generics)
+ | T.Str -> Adt (Assumed Str, generics)))
| TypeVar vid -> TypeVar vid
| Literal ty -> Literal ty
| Never -> raise (Failure "Unreachable")
| Ref (_, rty, _) -> translate rty
+ | RawPtr (ty, rkind) ->
+ let mut = match rkind with Mut -> Mut | Shared -> Const in
+ let ty = translate ty in
+ let generics = { types = [ ty ]; const_generics = []; trait_refs = [] } in
+ Adt (Assumed (RawPtr mut), generics)
+ | TraitType (trait_ref, generics, type_name) ->
+ let trait_ref = translate_strait_ref trait_ref in
+ let generics = translate_sgeneric_args generics in
+ TraitType (trait_ref, generics, type_name)
+ | Arrow _ -> raise (Failure "TODO")
+
+and translate_sgeneric_args (generics : T.sgeneric_args) : generic_args =
+ translate_generic_args translate_sty generics
+
+and translate_strait_ref (tr : T.strait_ref) : trait_ref =
+ translate_trait_ref translate_sty tr
+
+and translate_strait_instance_id (id : T.strait_instance_id) : trait_instance_id
+ =
+ translate_trait_instance_id translate_sty id
+
+let translate_trait_clause (clause : T.trait_clause) : trait_clause =
+ let { T.clause_id; meta = _; trait_id; generics } = clause in
+ let generics = translate_sgeneric_args generics in
+ { clause_id; trait_id; generics }
+
+let translate_strait_type_constraint (ttc : T.strait_type_constraint) :
+ trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let trait_ref = translate_strait_ref trait_ref in
+ let generics = translate_sgeneric_args generics in
+ let ty = translate_sty ty in
+ { trait_ref; generics; type_name; ty }
+
+let translate_predicates (preds : T.predicates) : predicates =
+ let trait_type_constraints =
+ List.map translate_strait_type_constraint preds.trait_type_constraints
+ in
+ { trait_type_constraints }
+
+let translate_generic_params (generics : T.generic_params) : generic_params =
+ let { T.regions = _; types; const_generics; trait_clauses } = generics in
+ let trait_clauses = List.map translate_trait_clause trait_clauses in
+ { types; const_generics; trait_clauses }
let translate_field (f : T.field) : field =
let field_name = f.field_name in
@@ -452,15 +495,16 @@ let translate_type_decl_kind (kind : T.type_decl_kind) : type_decl_kind =
point of moving this definition for now.
*)
let translate_type_decl (def : T.type_decl) : type_decl =
- (* Translate *)
let def_id = def.T.def_id in
let name = def.name in
+ let { T.regions; types; const_generics; trait_clauses } = def.generics in
(* Can't translate types with regions for now *)
- assert (def.region_params = []);
- let type_params = def.type_params in
- let const_generic_params = def.const_generic_params in
+ assert (regions = []);
+ let trait_clauses = List.map translate_trait_clause trait_clauses in
+ let generics = { types; const_generics; trait_clauses } in
let kind = translate_type_decl_kind def.T.kind in
- { def_id; name; type_params; const_generic_params; kind }
+ let preds = translate_predicates def.preds in
+ { def_id; name; generics; kind; preds }
let translate_type_id (id : T.type_id) : type_id =
match id with
@@ -468,12 +512,9 @@ let translate_type_id (id : T.type_id) : type_id =
| T.Assumed aty ->
let aty =
match aty with
- | T.Vec -> Vec
- | T.Option -> Option
| T.Array -> Array
| T.Slice -> Slice
| T.Str -> Str
- | T.Range -> Range
| T.Box ->
(* Boxes have to be eliminated: this type id shouldn't
be translated *)
@@ -488,28 +529,26 @@ let translate_type_id (id : T.type_id) : type_id =
let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty =
let translate = translate_fwd_ty type_infos in
match ty with
- | T.Adt (type_id, regions, tys, cgs) -> (
- (* Can't translate types with regions for now *)
- assert (regions = []);
- (* Translate the type parameters *)
- let t_tys = List.map translate tys in
+ | T.Adt (type_id, generics) -> (
+ let t_generics = translate_fwd_generic_args type_infos generics in
(* Eliminate boxes and simplify tuples *)
match type_id with
- | AdtId _
- | T.Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
- (* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
+ | AdtId _ | T.Assumed (T.Array | T.Slice | T.Str) ->
let type_id = translate_type_id type_id in
- Adt (type_id, t_tys, cgs)
+ Adt (type_id, t_generics)
| Tuple ->
(* Note that if there is exactly one type, [mk_simpl_tuple_ty] is the
identity *)
- mk_simpl_tuple_ty t_tys
+ mk_simpl_tuple_ty t_generics.types
| T.Assumed T.Box -> (
(* We eliminate boxes *)
(* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
- match t_tys with
+ assert (
+ not
+ (List.exists
+ (TypesUtils.ty_has_borrows type_infos)
+ generics.types));
+ match t_generics.types with
| [ bty ] -> bty
| _ ->
raise
@@ -520,12 +559,40 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty =
| Never -> raise (Failure "Unreachable")
| Literal lty -> Literal lty
| Ref (_, rty, _) -> translate rty
+ | RawPtr (ty, rkind) ->
+ let mut = match rkind with Mut -> Mut | Shared -> Const in
+ let ty = translate ty in
+ let generics = { types = [ ty ]; const_generics = []; trait_refs = [] } in
+ Adt (Assumed (RawPtr mut), generics)
+ | TraitType (trait_ref, generics, type_name) ->
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ TraitType (trait_ref, generics, type_name)
+ | Arrow _ -> raise (Failure "TODO")
+
+and translate_fwd_generic_args (type_infos : TA.type_infos)
+ (generics : 'r T.generic_args) : generic_args =
+ translate_generic_args (translate_fwd_ty type_infos) generics
+
+and translate_fwd_trait_ref (type_infos : TA.type_infos) (tr : 'r T.trait_ref) :
+ trait_ref =
+ translate_trait_ref (translate_fwd_ty type_infos) tr
+
+and translate_fwd_trait_instance_id (type_infos : TA.type_infos)
+ (id : 'r T.trait_instance_id) : trait_instance_id =
+ translate_trait_instance_id (translate_fwd_ty type_infos) id
(** Simply calls [translate_fwd_ty] *)
let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
let type_infos = ctx.type_context.type_infos in
translate_fwd_ty type_infos ty
+(** Simply calls [translate_fwd_generic_args] *)
+let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : 'r T.generic_args)
+ : generic_args =
+ let type_infos = ctx.type_context.type_infos in
+ translate_fwd_generic_args type_infos generics
+
(** Translate a type, when some regions may have ended.
We return an option, because the translated type may be empty.
@@ -538,30 +605,40 @@ let rec translate_back_ty (type_infos : TA.type_infos)
(* A small helper for "leave" types *)
let wrap ty = if inside_mut then Some ty else None in
match ty with
- | T.Adt (type_id, _, tys, cgs) -> (
+ | T.Adt (type_id, generics) -> (
match type_id with
- | T.AdtId _
- | Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
- (* Don't accept ADTs (which are not tuples) with borrows for now *)
- assert (not (TypesUtils.ty_has_borrows type_infos ty));
+ | T.AdtId _ | Assumed (T.Array | T.Slice | T.Str) ->
let type_id = translate_type_id type_id in
if inside_mut then
- let tys_t = List.filter_map translate tys in
- Some (Adt (type_id, tys_t, cgs))
- else None
+ (* We do not want to filter anything, so we translate the generics
+ as "forward" types *)
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (Adt (type_id, generics))
+ else
+ (* If not inside a mutable reference: check if at least one
+ of the generics contains a mutable reference (i.e., is not
+ translated to `None`. If yes, keep the whole type, and
+ translate all the generics as "forward" types (the backward
+ function will extract the proper information from the ADT value)
+ *)
+ let types = List.filter_map translate generics.types in
+ if types <> [] then
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (Adt (type_id, generics))
+ else None
| Assumed T.Box -> (
(* Don't accept ADTs (which are not tuples) with borrows for now *)
assert (not (TypesUtils.ty_has_borrows type_infos ty));
(* Eliminate the box *)
- match tys with
+ match generics.types with
| [ bty ] -> translate bty
| _ ->
raise
(Failure "Unreachable: boxes receive exactly one type parameter")
)
| T.Tuple -> (
- (* Tuples can contain borrows (which we eliminated) *)
- let tys_t = List.filter_map translate tys in
+ (* Tuples can contain borrows (which we eliminate) *)
+ let tys_t = List.filter_map translate generics.types in
match tys_t with
| [] -> None
| _ ->
@@ -582,6 +659,17 @@ let rec translate_back_ty (type_infos : TA.type_infos)
if keep_region r then
translate_back_ty type_infos keep_region inside_mut rty
else None)
+ | RawPtr _ ->
+ (* TODO: not sure what to do here *)
+ None
+ | TraitType (trait_ref, generics, type_name) ->
+ assert (generics.regions = []);
+ (* Translate the trait ref and the generics as "forward" generics -
+ we do not want to filter any type *)
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (TraitType (trait_ref, generics, type_name))
+ | Arrow _ -> raise (Failure "TODO")
(** Simply calls [translate_back_ty] *)
let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
@@ -589,6 +677,80 @@ let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
let type_infos = ctx.type_context.type_infos in
translate_back_ty type_infos keep_region inside_mut ty
+let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx =
+ let const_generics =
+ T.ConstGenericVarId.Map.of_list
+ (List.map
+ (fun (cg : T.const_generic_var) ->
+ (cg.index, ctx_translate_fwd_ty ctx (T.Literal cg.ty)))
+ ctx.sg.generics.const_generics)
+ in
+ let env = VarId.Map.empty in
+ {
+ PureTypeCheck.type_decls = ctx.type_context.type_decls;
+ global_decls = ctx.global_context.llbc_global_decls;
+ env;
+ const_generics;
+ }
+
+let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit =
+ let ctx = mk_type_check_ctx ctx in
+ let _ = PureTypeCheck.check_typed_pattern ctx v in
+ ()
+
+let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit =
+ if !Config.type_check_pure_code then
+ let ctx = mk_type_check_ctx ctx in
+ PureTypeCheck.check_texpression ctx e
+
+let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
+ (id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref =
+ match id with
+ | FunId fun_id -> FunId fun_id
+ | TraitMethod (trait_ref, method_name, fun_decl_id) ->
+ let type_infos = ctx.type_context.type_infos in
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ TraitMethod (trait_ref, method_name, fun_decl_id)
+
+let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
+ (args : texpression list) (ctx : bs_ctx) : bs_ctx =
+ let calls = ctx.calls in
+ assert (not (V.FunCallId.Map.mem call_id calls));
+ let info =
+ { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
+ in
+ let calls = V.FunCallId.Map.add call_id info calls in
+ { ctx with calls }
+
+(** [back_args]: the *additional* list of inputs received by the backward function *)
+let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
+ (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
+ : bs_ctx * fun_or_op_id =
+ (* Insert the abstraction in the call informations *)
+ let info = V.FunCallId.Map.find call_id ctx.calls in
+ assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
+ let backwards =
+ T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
+ in
+ let info = { info with backwards } in
+ let calls = V.FunCallId.Map.add call_id info ctx.calls in
+ (* Insert the abstraction in the abstractions map *)
+ let abstractions = ctx.abstractions in
+ assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
+ let abstractions =
+ V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
+ in
+ (* Retrieve the fun_id *)
+ let fun_id =
+ match info.forward.call_id with
+ | S.Fun (fid, _) ->
+ let fid = translate_fun_id_or_trait_method_ref ctx fid in
+ Fun (FromLlbc (fid, None, Some back_id))
+ | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ in
+ (* Update the context and return *)
+ ({ ctx with calls; abstractions }, fun_id)
+
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
(call_id : V.FunCallId.id) : V.AbstractionId.id list =
@@ -642,10 +804,10 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
(** Small utility. *)
let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (lid : V.LoopId.id option)
+ (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option)
(gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
- | A.Regular fid ->
+ | TraitMethod (_, _, fid) | FunId (Regular fid) ->
let info = A.FunDeclId.Map.find fid fun_infos in
let stateful_group = info.stateful in
let stateful =
@@ -658,10 +820,10 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
can_diverge = info.can_diverge;
is_rec = info.is_rec || Option.is_some lid;
}
- | A.Assumed aid ->
+ | FunId (Assumed aid) ->
assert (lid = None);
{
- can_fail = Assumed.assumed_can_fail aid;
+ can_fail = Assumed.assumed_fun_can_fail aid;
stateful_group = false;
stateful = false;
can_diverge = false;
@@ -673,12 +835,14 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
Note that the function also takes a list of names for the inputs, and
computes, for every output for the backward functions, a corresponding
name (outputs for backward functions come from borrows in the inputs
- of the forward function) which we use as hints to generate pretty names.
+ of the forward function) which we use as hints to generate pretty names
+ in the extracted code.
*)
-let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (type_infos : TA.type_infos) (sg : A.fun_sig)
- (input_names : string option list) (bid : T.RegionGroupId.id option) :
- fun_sig_named_outputs =
+let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
+ (sg : A.fun_sig) (input_names : string option list)
+ (bid : T.RegionGroupId.id option) : fun_sig_named_outputs =
+ let fun_infos = decls_ctx.fun_ctx.fun_infos in
+ let type_infos = decls_ctx.type_ctx.type_infos in
(* Retrieve the list of parent backward functions *)
let gid, parents =
match bid with
@@ -689,7 +853,34 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
in
(* Is the function stateful, and can it fail? *)
let lid = None in
- let effect_info = get_fun_effect_info fun_infos fun_id lid bid in
+ let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in
+ (* We need an evaluation context to normalize the types (to normalize the
+ associated types, etc. - for instance it may happen that the types
+ refer to the types associated to a trait ref, but where the trait ref
+ is a known impl). *)
+ (* Create the context *)
+ let ctx =
+ let region_groups =
+ List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy
+ in
+ let ctx =
+ InterpreterUtils.initialize_eval_context decls_ctx region_groups
+ sg.generics.types sg.generics.const_generics
+ in
+ (* Compute the normalization map for the *sty* types and add it to the context *)
+ AssociatedTypes.ctx_add_norm_trait_stypes_from_preds ctx
+ sg.preds.trait_type_constraints
+ in
+
+ (* Normalize the signature *)
+ let sg =
+ let ({ A.inputs; output; _ } : A.fun_sig) = sg in
+ let norm = AssociatedTypes.ctx_normalize_sty ctx in
+ let inputs = List.map norm inputs in
+ let output = norm output in
+ { sg with A.inputs; output }
+ in
+
(* List the inputs for:
* - the fuel
* - the forward function
@@ -806,9 +997,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(* Wrap in a result type *)
if effect_info.can_fail then mk_result_ty output else output
in
- (* Type/const generic parameters *)
- let type_params = sg.type_params in
- let const_generic_params = sg.const_generic_params in
+ (* Generic parameters *)
+ let generics = translate_generic_params sg.generics in
(* Return *)
let has_fuel = fuel <> [] in
let num_fwd_inputs_no_state = List.length fwd_inputs in
@@ -836,9 +1026,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
effect_info;
}
in
- let sg =
- { type_params; const_generic_params; inputs; output; doutputs; info }
- in
+ let preds = translate_predicates sg.A.preds in
+ let sg = { generics; preds; inputs; output; doutputs; info } in
{ sg; output_names }
let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
@@ -917,7 +1106,7 @@ let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
(** Peel boxes as long as the value is of the form [Box<T>] *)
let rec unbox_typed_value (v : V.typed_value) : V.typed_value =
match (v.value, v.ty) with
- | V.Adt av, T.Adt (T.Assumed T.Box, _, _, _) -> (
+ | V.Adt av, T.Adt (T.Assumed T.Box, _) -> (
match av.field_values with
| [ bv ] -> unbox_typed_value bv
| _ -> raise (Failure "Unreachable"))
@@ -962,16 +1151,16 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx)
let field_values = List.map translate av.field_values in
(* Eliminate the tuple wrapper if it is a tuple with exactly one field *)
match v.ty with
- | T.Adt (T.Tuple, _, _, _) ->
+ | T.Adt (T.Tuple, _) ->
assert (variant_id = None);
mk_simpl_tuple_texpression field_values
| _ ->
- (* Retrieve the type, the translated type arguments and the
- * const generic arguments from the translated type (simpler this way) *)
- let adt_id, type_args, const_generic_args = ty_as_adt ty in
+ (* Retrieve the type and the translated generics from the translated
+ type (simpler this way) *)
+ let adt_id, generics = ty_as_adt ty in
(* Create the constructor *)
let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in
- let qualif = { id = qualif_id; type_args; const_generic_args } in
+ let qualif = { id = qualif_id; generics } in
let cons_e = Qualif qualif in
let field_tys =
List.map (fun (v : texpression) -> v.ty) field_values
@@ -1038,11 +1227,9 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx)
(* Translate the field values *)
let field_values = List.filter_map translate adt_v.field_values in
(* For now, only tuples can contain borrows *)
- let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in
+ let adt_id, _ = TypesUtils.ty_as_adt av.ty in
match adt_id with
- | T.AdtId _
- | T.Assumed
- (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
+ | T.AdtId _ | T.Assumed (T.Box | T.Array | T.Slice | T.Str) ->
assert (field_values = []);
None
| T.Tuple ->
@@ -1185,11 +1372,9 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue)
(* For now, only tuples can contain borrows - note that if we gave
* something like a [&mut Vec] to a function, we give back the
* vector value upon visiting the "abstraction borrow" node *)
- let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in
+ let adt_id, _ = TypesUtils.ty_as_adt av.ty in
match adt_id with
- | T.AdtId _
- | T.Assumed
- (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
+ | T.AdtId _ | T.Assumed (T.Box | T.Array | T.Slice | T.Str) ->
assert (field_values = []);
(ctx, None)
| T.Tuple ->
@@ -1457,9 +1642,12 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
texpression =
+ log#ldebug
+ (lazy
+ ("translate_function_call:\n"
+ ^ ctx_egeneric_args_to_string ctx call.generics));
(* Translate the function call *)
- let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- let const_generic_args = call.const_generic_params in
+ let generics = ctx_translate_fwd_generic_args ctx call.generics in
let args =
let args = List.map (typed_value_to_texpression ctx call.ctx) call.args in
let args_mplaces = List.map translate_opt_mplace call.args_places in
@@ -1475,7 +1663,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
- let func = Fun (FromLlbc (fid, None, None)) in
+ let fid_t = translate_fun_id_or_trait_method_ref ctx fid in
+ let func = Fun (FromLlbc (fid_t, None, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
@@ -1525,18 +1714,20 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
in
(ctx, Unop (Neg int_ty), effect_info, args, None)
| _ -> raise (Failure "Unreachable"))
- | S.Unop (E.Cast (src_ty, tgt_ty)) ->
- (* Note that cast can fail *)
- let effect_info =
- {
- can_fail = true;
- stateful_group = false;
- stateful = false;
- can_diverge = false;
- is_rec = false;
- }
- in
- (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
+ | S.Unop (E.Cast cast_kind) -> (
+ match cast_kind with
+ | CastInteger (src_ty, tgt_ty) ->
+ (* Note that cast can fail *)
+ let effect_info =
+ {
+ can_fail = true;
+ stateful_group = false;
+ stateful = false;
+ can_diverge = false;
+ is_rec = false;
+ }
+ in
+ (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None))
| S.Binop binop -> (
match args with
| [ arg0; arg1 ] ->
@@ -1561,7 +1752,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| None -> dest
| Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
in
- let func = { id = FunOrOp fun_id; type_args; const_generic_args } in
+ let func = { id = FunOrOp fun_id; generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty
@@ -1665,9 +1856,11 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
(* Group the two lists *)
let variables_values = List.combine given_back_variables consumed_values in
(* Sanity check: the two lists match (same types) *)
- List.iter
- (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
- variables_values;
+ (* TODO: normalize the types *)
+ if !Config.type_check_pure_code then
+ List.iter
+ (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
+ variables_values;
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Generate the assignemnts *)
@@ -1692,8 +1885,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
let effect_info =
get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id)
in
- let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- let const_generic_args = call.const_generic_params in
+ let generics = ctx_translate_fwd_generic_args ctx call.generics in
(* Retrieve the original call and the parent abstractions *)
let _forward, backwards = get_abs_ancestors ctx abs call_id in
(* Retrieve the values consumed when we called the forward function and
@@ -1741,34 +1933,35 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
(* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *)
- let _ =
- let inst_sg =
- get_instantiated_fun_sig fun_id (Some rg_id) type_args const_generic_args
- ctx
- in
- log#ldebug
- (lazy
- ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
- ^ string_of_int (List.length inputs)
- ^ "): "
- ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
- ^ "\n- inst_sg.inputs ("
- ^ string_of_int (List.length inst_sg.inputs)
- ^ "): "
- ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
- List.iter
- (fun (x, ty) -> assert ((x : texpression).ty = ty))
- (List.combine inputs inst_sg.inputs);
- log#ldebug
- (lazy
- ("\n- outputs: "
- ^ string_of_int (List.length outputs)
- ^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.doutputs)));
- List.iter
- (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.doutputs)
- in
+ (if (* TODO: normalize the types *) !Config.type_check_pure_code then
+ match fun_id with
+ | FunId fun_id ->
+ let inst_sg =
+ get_instantiated_fun_sig fun_id (Some rg_id) generics ctx
+ in
+ log#ldebug
+ (lazy
+ ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
+ ^ string_of_int (List.length inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
+ ^ "\n- inst_sg.inputs ("
+ ^ string_of_int (List.length inst_sg.inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : texpression).ty = ty))
+ (List.combine inputs inst_sg.inputs);
+ log#ldebug
+ (lazy
+ ("\n- outputs: "
+ ^ string_of_int (List.length outputs)
+ ^ "\n- expected outputs: "
+ ^ string_of_int (List.length inst_sg.doutputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
+ (List.combine outputs inst_sg.doutputs)
+ | _ -> (* TODO: trait methods *) ());
(* Retrieve the function id, and register the function call in the context
* if necessary *)
let ctx, func =
@@ -1788,7 +1981,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp func; type_args; const_generic_args } in
+ let func = { id = FunOrOp func; generics } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* **Optimization**:
@@ -1905,14 +2098,13 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
(* Actually the same case as [SynthInput] *)
translate_end_abstraction_synth_input ectx abs e ctx rg_id
| V.LoopCall ->
- let fun_id = A.Regular ctx.fun_decl.A.def_id in
+ let fun_id = E.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id)
- (Some rg_id)
+ get_fun_effect_info ctx.fun_context.fun_infos (FunId fun_id)
+ (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
- let type_args = loop_info.type_args in
- let const_generic_args = loop_info.const_generic_args in
+ let generics = loop_info.generics in
let fwd_inputs = Option.get loop_info.forward_inputs in
(* Retrieve the additional backward inputs. Note that those are actually
the backward inputs of the function we are synthesizing (and that we
@@ -1960,8 +2152,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in
- let func = { id = FunOrOp func; type_args; const_generic_args } in
+ let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
+ let func = { id = FunOrOp func; generics } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* **Optimization**:
@@ -2021,9 +2213,7 @@ and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
let ctx, var = fresh_var_for_symbolic_value sval ctx in
let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in
- let global_expr =
- { id = Global gid; type_args = []; const_generic_args = [] }
- in
+ let global_expr = { id = Global gid; generics = empty_generic_args } in
(* We use translate_fwd_ty to translate the global type *)
let ty = ctx_translate_fwd_ty ctx decl.ty in
let gval = { e = Qualif global_expr; ty } in
@@ -2037,11 +2227,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value)
let v = typed_value_to_texpression ctx ectx v in
let args = [ v ] in
let func =
- {
- id = FunOrOp (Fun (Pure Assert));
- type_args = [];
- const_generic_args = [];
- }
+ { id = FunOrOp (Fun (Pure Assert)); generics = empty_generic_args }
in
let func_ty = mk_arrow (Literal Bool) mk_unit_ty in
let func = { e = Qualif func; ty = func_ty } in
@@ -2189,7 +2375,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
(branch : S.expression) (ctx : bs_ctx) : texpression =
(* TODO: always introduce a match, and use micro-passes to turn the
the match into a let? *)
- let type_id, _, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
+ let type_id, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
let branch = translate_expression branch ctx in
match type_id with
@@ -2224,10 +2410,10 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
* field.
* We use the [dest] variable in order not to have to recompute
* the type of the result of the projection... *)
- let adt_id, type_args, const_generic_args = ty_as_adt scrutinee.ty in
+ let adt_id, generics = ty_as_adt scrutinee.ty in
let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression =
let proj_kind = { adt_id; field_id } in
- let qualif = { id = Proj proj_kind; type_args; const_generic_args } in
+ let qualif = { id = Proj proj_kind; generics } in
let proj_e = Qualif qualif in
let proj_ty = mk_arrow scrutinee.ty dest.ty in
let proj = { e = proj_e; ty = proj_ty } in
@@ -2259,17 +2445,12 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
(mk_typed_pattern_from_var var None)
(mk_opt_mplace_texpression scrutinee_mplace scrutinee)
branch
- | T.Assumed (T.Vec | T.Array | T.Slice | T.Str) ->
+ | T.Assumed (T.Array | T.Slice | T.Str) ->
(* We can't expand those values: we can access the fields only
* through the functions provided by the API (note that we don't
* know how to expand values like vectors or arrays, because they have a variable number
* of fields!) *)
raise (Failure "Attempt to expand a non-expandable value")
- | T.Assumed Range -> raise (Failure "Unimplemented")
- | T.Assumed T.Option ->
- (* We shouldn't get there in the "one-branch" case: options have
- * two variants *)
- raise (Failure "Unreachable")
and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
(sv : V.symbolic_value) (v : S.value_aggregate) (e : S.expression)
@@ -2282,8 +2463,9 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
(* Translate the next expression *)
let next_e = translate_expression e ctx in
- (* Translate the value: there are two cases, depending on whether this
- is a "regular" let-binding or an array aggregate.
+ (* Translate the value: there are several cases, depending on whether this
+ is a "regular" let-binding, an array aggregate, a const generic or
+ a trait associated constant.
*)
let v =
match v with
@@ -2298,6 +2480,14 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
{ struct_id = Assumed Array; init = None; updates = values }
in
{ e = StructUpdate su; ty = var.ty }
+ | ConstGenericValue cg_id -> { e = CVar cg_id; ty = var.ty }
+ | TraitConstValue (trait_ref, generics, const_name) ->
+ let type_infos = ctx.type_context.type_infos in
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ let qualif_id = TraitConst (trait_ref, generics, const_name) in
+ let qualif = { id = qualif_id; generics = empty_generic_args } in
+ { e = Qualif qualif; ty = var.ty }
in
(* Make the let-binding *)
@@ -2368,9 +2558,9 @@ and translate_forward_end (ectx : C.eval_ctx)
let org_args = args in
(* Lookup the effect info for the loop function *)
- let fid = A.Regular ctx.fun_decl.A.def_id in
+ let fid = E.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid
+ get_fun_effect_info ctx.fun_context.fun_infos (FunId fid) None ctx.bid
in
(* Introduce a fresh output value for the forward function *)
@@ -2415,14 +2605,8 @@ and translate_forward_end (ectx : C.eval_ctx)
let out_pat = mk_simpl_tuple_pattern out_pats in
let loop_call =
- let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in
- let func =
- {
- id = FunOrOp fun_id;
- type_args = loop_info.type_args;
- const_generic_args = loop_info.const_generic_args;
- }
- in
+ let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in
+ let func = { id = FunOrOp fun_id; generics = loop_info.generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty
@@ -2541,14 +2725,31 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Note that we will retrieve the input values later in the [ForwardEnd]
(and will introduce the outputs at that moment, together with the actual
- call to the loop forward function *)
- let type_args =
- List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) ctx.sg.type_params
- in
- let const_generic_args =
- List.map
- (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index)
- ctx.sg.const_generic_params
+ call to the loop forward function) *)
+ let generics =
+ let { types; const_generics; trait_clauses } = ctx.sg.generics in
+ let types =
+ List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) types
+ in
+ let const_generics =
+ List.map
+ (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index)
+ const_generics
+ in
+ let trait_refs =
+ List.map
+ (fun (c : trait_clause) ->
+ let trait_decl_ref =
+ { trait_decl_id = c.trait_id; decl_generics = empty_generic_args }
+ in
+ {
+ trait_id = Clause c.clause_id;
+ generics = empty_generic_args;
+ trait_decl_ref;
+ })
+ trait_clauses
+ in
+ { types; const_generics; trait_refs }
in
let loop_info =
@@ -2556,8 +2757,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
loop_id;
input_vars = inputs;
input_svl = loop.input_svalues;
- type_args;
- const_generic_args;
+ generics;
forward_inputs = None;
forward_output_no_state_no_result = None;
}
@@ -2648,8 +2848,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
let func =
{
id = FunOrOp (Fun (Pure FuelEqZero));
- type_args = [];
- const_generic_args = [];
+ generics = empty_generic_args;
}
in
let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in
@@ -2661,8 +2860,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
let func =
{
id = FunOrOp (Fun (Pure FuelDecrease));
- type_args = [];
- const_generic_args = [];
+ generics = empty_generic_args;
}
in
let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in
@@ -2727,8 +2925,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None
- bid
+ get_fun_effect_info ctx.fun_context.fun_infos (FunId (Regular def_id))
+ None bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
@@ -2803,10 +3001,12 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
^ "\n- signature.inputs: "
^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs)
));
- assert (
- List.for_all
- (fun (var, ty) -> (var : var).ty = ty)
- (List.combine inputs signature.inputs));
+ (* TODO: we need to normalize the types *)
+ if !Config.type_check_pure_code then
+ assert (
+ List.for_all
+ (fun (var, ty) -> (var : var).ty = ty)
+ (List.combine inputs signature.inputs));
Some { inputs; inputs_lvs; body }
in
@@ -2821,6 +3021,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let def =
{
def_id;
+ kind = def.kind;
num_loops;
loop_id;
back_id = bid;
@@ -2853,8 +3054,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list =
- optional names for the outputs values (we derive them for the backward
functions)
*)
-let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (type_infos : TA.type_infos)
+let translate_fun_signatures (decls_ctx : C.decls_ctx)
(functions : (A.fun_id * string option list * A.fun_sig) list) :
fun_sig_named_outputs RegularFunIdNotLoopMap.t =
(* For every function, translate the signatures of:
@@ -2865,17 +3065,14 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list
=
(* The forward function *)
- let fwd_sg =
- translate_fun_sig fun_infos fun_id type_infos sg input_names None
- in
+ let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in
let fwd_id = (fun_id, None) in
(* The backward functions *)
let back_sgs =
List.map
(fun (rg : T.region_var_group) ->
let tsg =
- translate_fun_sig fun_infos fun_id type_infos sg input_names
- (Some rg.id)
+ translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id)
in
let id = (fun_id, Some rg.id) in
(id, tsg))
@@ -2891,3 +3088,94 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
List.fold_left
(fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m)
RegularFunIdNotLoopMap.empty translated
+
+let translate_trait_decl (type_infos : TA.type_infos)
+ (trait_decl : A.trait_decl) : trait_decl =
+ let {
+ def_id;
+ name;
+ generics;
+ preds;
+ parent_clauses;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } : A.trait_decl =
+ trait_decl
+ in
+ let generics = translate_generic_params generics in
+ let preds = translate_predicates preds in
+ let parent_clauses = List.map translate_trait_clause parent_clauses in
+ let consts =
+ List.map
+ (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id)))
+ consts
+ in
+ let types =
+ List.map
+ (fun (name, (trait_clauses, ty)) ->
+ ( name,
+ ( List.map translate_trait_clause trait_clauses,
+ Option.map (translate_fwd_ty type_infos) ty ) ))
+ types
+ in
+ {
+ def_id;
+ name;
+ generics;
+ preds;
+ parent_clauses;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ }
+
+let translate_trait_impl (type_infos : TA.type_infos)
+ (trait_impl : A.trait_impl) : trait_impl =
+ let {
+ A.def_id;
+ name;
+ impl_trait;
+ generics;
+ preds;
+ parent_trait_refs;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } =
+ trait_impl
+ in
+ let impl_trait =
+ translate_trait_decl_ref (translate_fwd_ty type_infos) impl_trait
+ in
+ let generics = translate_generic_params generics in
+ let preds = translate_predicates preds in
+ let parent_trait_refs = List.map translate_strait_ref parent_trait_refs in
+ let consts =
+ List.map
+ (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id)))
+ consts
+ in
+ let types =
+ List.map
+ (fun (name, (trait_refs, ty)) ->
+ ( name,
+ ( List.map (translate_fwd_trait_ref type_infos) trait_refs,
+ translate_fwd_ty type_infos ty ) ))
+ types
+ in
+ {
+ def_id;
+ name;
+ impl_trait;
+ generics;
+ preds;
+ parent_trait_refs;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ }
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index 857fea97..9dd65c84 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -64,7 +64,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
assert (otherwise_see = None);
(* Return *)
ExpandInt (int_ty, branches, otherwise)
- | T.Adt (_, _, _, _) ->
+ | T.Adt (_, _) ->
(* Branching: it is necessarily an enumeration expansion *)
let get_variant (see : V.symbolic_expansion option) :
T.VariantId.id option * V.symbolic_value list =
@@ -85,7 +85,9 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
match ls with
| [ (Some see, exp) ] -> ExpandNoBranch (see, exp)
| _ -> raise (Failure "Ill-formed borrow expansion"))
- | T.TypeVar _ | T.Literal Char | Never ->
+ | T.TypeVar _
+ | T.Literal Char
+ | Never | T.TraitType _ | T.Arrow _ | T.RawPtr _ ->
raise (Failure "Ill-formed symbolic expansion")
in
Some (Expansion (place, sv, expansion))
@@ -97,10 +99,10 @@ let synthesize_symbolic_expansion_no_branching (sv : V.symbolic_value)
synthesize_symbolic_expansion sv place [ Some see ] el
let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
- (abstractions : V.AbstractionId.id list) (type_params : T.ety list)
- (const_generic_params : T.const_generic list) (args : V.typed_value list)
- (args_places : mplace option list) (dest : V.symbolic_value)
- (dest_place : mplace option) (e : expression option) : expression option =
+ (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args)
+ (args : V.typed_value list) (args_places : mplace option list)
+ (dest : V.symbolic_value) (dest_place : mplace option)
+ (e : expression option) : expression option =
Option.map
(fun e ->
let call =
@@ -108,8 +110,7 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
call_id;
ctx;
abstractions;
- type_params;
- const_generic_params;
+ generics;
args;
dest;
args_places;
@@ -123,28 +124,29 @@ let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value)
(e : expression option) : expression option =
Option.map (fun e -> EvalGlobal (gid, dest, e)) e
-let synthesize_regular_function_call (fun_id : A.fun_id)
+let synthesize_regular_function_call (fun_id : A.fun_id_or_trait_method_ref)
(call_id : V.FunCallId.id) (ctx : Contexts.eval_ctx)
- (abstractions : V.AbstractionId.id list) (type_params : T.ety list)
- (const_generic_params : T.const_generic list) (args : V.typed_value list)
- (args_places : mplace option list) (dest : V.symbolic_value)
- (dest_place : mplace option) (e : expression option) : expression option =
+ (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args)
+ (args : V.typed_value list) (args_places : mplace option list)
+ (dest : V.symbolic_value) (dest_place : mplace option)
+ (e : expression option) : expression option =
synthesize_function_call
(Fun (fun_id, call_id))
- ctx abstractions type_params const_generic_params args args_places dest
- dest_place e
+ ctx abstractions generics args args_places dest dest_place e
let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : E.unop)
(arg : V.typed_value) (arg_place : mplace option) (dest : V.symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
- synthesize_function_call (Unop unop) ctx [] [] [] [ arg ] [ arg_place ] dest
- dest_place e
+ let generics = TypesUtils.mk_empty_generic_args in
+ synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ]
+ dest dest_place e
let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : E.binop)
(arg0 : V.typed_value) (arg0_place : mplace option) (arg1 : V.typed_value)
(arg1_place : mplace option) (dest : V.symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
- synthesize_function_call (Binop binop) ctx [] [] [] [ arg0; arg1 ]
+ let generics = TypesUtils.mk_empty_generic_args in
+ synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ]
[ arg0_place; arg1_place ] dest dest_place e
let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : V.abs)
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 70ef5e3d..a3d96023 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -5,6 +5,7 @@ module T = Types
module A = LlbcAst
module SA = SymbolicAst
module Micro = PureMicroPasses
+module C = Contexts
open PureUtils
open TranslateCore
@@ -28,18 +29,12 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl)
("translate_function_to_symbolics: "
^ Print.fun_name_to_string fdef.A.name));
- let { type_context; fun_context; global_context } = trans_ctx in
- let fun_context = { C.fun_decls = fun_context.fun_decls } in
-
match fdef.body with
| None -> None
| Some _ ->
(* Evaluate *)
let synthesize = true in
- let inputs, symb =
- evaluate_function_symbolic synthesize type_context fun_context
- global_context fdef
- in
+ let inputs, symb = evaluate_function_symbolic synthesize trans_ctx fdef in
Some (inputs, Option.get symb)
(** Translate a function, by generating its forward and backward translations.
@@ -57,7 +52,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(lazy
("translate_function_to_pure: " ^ Print.fun_name_to_string fdef.A.name));
- let { type_context; fun_context; global_context } = trans_ctx in
let def_id = fdef.def_id in
(* Compute the symbolic ASTs, if the function is transparent *)
@@ -67,7 +61,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(* Initialize the context *)
let forward_sig =
- RegularFunIdNotLoopMap.find (A.Regular def_id, None) fun_sigs
+ RegularFunIdNotLoopMap.find (E.Regular def_id, None) fun_sigs
in
let sv_to_var = V.SymbolicValueId.Map.empty in
let var_counter = Pure.VarId.generator_zero in
@@ -82,25 +76,25 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(List.filter_map
(fun (tid, g) ->
match g with Charon.GAst.NonRec _ -> None | Rec _ -> Some tid)
- (T.TypeDeclId.Map.bindings trans_ctx.type_context.type_decls_groups))
+ (T.TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups))
in
let type_context =
{
- SymbolicToPure.type_infos = type_context.type_infos;
- llbc_type_decls = type_context.type_decls;
+ SymbolicToPure.type_infos = trans_ctx.type_ctx.type_infos;
+ llbc_type_decls = trans_ctx.type_ctx.type_decls;
type_decls = pure_type_decls;
recursive_decls = recursive_type_decls;
}
in
let fun_context =
{
- SymbolicToPure.llbc_fun_decls = fun_context.fun_decls;
+ SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls;
fun_sigs;
- fun_infos = fun_context.fun_infos;
+ fun_infos = trans_ctx.fun_ctx.fun_infos;
}
in
let global_context =
- { SymbolicToPure.llbc_global_decls = global_context.global_decls }
+ { SymbolicToPure.llbc_global_decls = trans_ctx.global_ctx.global_decls }
in
(* Compute the set of loops, and find better ids for them (starting at 0).
@@ -148,6 +142,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
type_context;
fun_context;
global_context;
+ trait_decls_ctx = trans_ctx.trait_decls_ctx.trait_decls;
+ trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls;
fun_decl = fdef;
forward_inputs = [];
(* Empty for now *)
@@ -204,7 +200,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(* Initialize the context - note that the ret_ty is not really
* useful as we don't translate a body *)
let backward_sg =
- RegularFunIdNotLoopMap.find (A.Regular def_id, Some back_id) fun_sigs
+ RegularFunIdNotLoopMap.find (Regular def_id, Some back_id) fun_sigs
in
let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in
@@ -215,7 +211,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
variables required by the backward function.
*)
let backward_sg =
- RegularFunIdNotLoopMap.find (A.Regular def_id, Some back_id) fun_sigs
+ RegularFunIdNotLoopMap.find (Regular def_id, Some back_id) fun_sigs
in
(* We need to ignore the forward inputs, and the state input (if there is) *)
let backward_inputs =
@@ -274,21 +270,18 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(* Return *)
(pure_forward, pure_backwards)
+(* TODO: factor out the return type *)
let translate_crate_to_pure (crate : A.crate) :
- trans_ctx * Pure.type_decl list * (bool * pure_fun_translation) list =
+ trans_ctx
+ * Pure.type_decl list
+ * pure_fun_translation list
+ * Pure.trait_decl list
+ * Pure.trait_impl list =
(* Debug *)
log#ldebug (lazy "translate_crate_to_pure");
- (* Compute the type and function contexts *)
- let type_context, fun_context, global_context =
- compute_type_fun_global_contexts crate
- in
- let fun_infos =
- FA.analyze_module crate fun_context.C.fun_decls
- global_context.C.global_decls !Config.use_state
- in
- let fun_context = { fun_decls = fun_context.fun_decls; fun_infos } in
- let trans_ctx = { type_context; fun_context; global_context } in
+ (* Compute the translation context *)
+ let trans_ctx = compute_contexts crate in
(* Translate all the type definitions *)
let type_decls =
@@ -304,9 +297,11 @@ let translate_crate_to_pure (crate : A.crate) :
(* Translate all the function *signatures* *)
let assumed_sigs =
List.map
- (fun (id, sg, _, _) ->
- (A.Assumed id, List.map (fun _ -> None) (sg : A.fun_sig).inputs, sg))
- Assumed.assumed_infos
+ (fun (info : Assumed.assumed_fun_info) ->
+ ( E.Assumed info.fun_id,
+ List.map (fun _ -> None) info.fun_sig.inputs,
+ info.fun_sig ))
+ Assumed.assumed_fun_infos
in
let local_sigs =
List.map
@@ -319,14 +314,11 @@ let translate_crate_to_pure (crate : A.crate) :
(fun (v : A.var) -> v.name)
(LlbcAstUtils.fun_body_get_input_vars body)
in
- (A.Regular fdef.def_id, input_names, fdef.signature))
+ (E.Regular fdef.def_id, input_names, fdef.signature))
(A.FunDeclId.Map.values crate.functions)
in
let sigs = List.append assumed_sigs local_sigs in
- let fun_sigs =
- SymbolicToPure.translate_fun_signatures fun_context.fun_infos
- type_context.type_infos sigs
- in
+ let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in
(* Translate all the *transparent* functions *)
let pure_translations =
@@ -335,28 +327,38 @@ let translate_crate_to_pure (crate : A.crate) :
(A.FunDeclId.Map.values crate.functions)
in
+ (* Translate the trait declarations *)
+ let type_infos = trans_ctx.type_ctx.type_infos in
+ let trait_decls =
+ List.map
+ (SymbolicToPure.translate_trait_decl type_infos)
+ (T.TraitDeclId.Map.values trans_ctx.trait_decls_ctx.trait_decls)
+ in
+
+ (* Translate the trait implementations *)
+ let trait_impls =
+ List.map
+ (SymbolicToPure.translate_trait_impl type_infos)
+ (T.TraitImplId.Map.values trans_ctx.trait_impls_ctx.trait_impls)
+ in
+
(* Apply the micro-passes *)
let pure_translations =
Micro.apply_passes_to_pure_fun_translations trans_ctx pure_translations
in
(* Return *)
- (trans_ctx, type_decls, pure_translations)
-
-(** Extraction context *)
-type gen_ctx = {
- crate : A.crate;
- extract_ctx : ExtractBase.extraction_ctx;
- trans_types : Pure.type_decl Pure.TypeDeclId.Map.t;
- trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t;
- functions_with_decreases_clause : PureUtils.FunLoopIdSet.t;
-}
+ (trans_ctx, type_decls, pure_translations, trait_decls, trait_impls)
+
+type gen_ctx = ExtractBase.extraction_ctx
type gen_config = {
extract_types : bool;
extract_decreases_clauses : bool;
extract_template_decreases_clauses : bool;
extract_fun_decls : bool;
+ extract_trait_decls : bool;
+ extract_trait_impls : bool;
extract_transparent : bool;
(** If [true], extract the transparent declarations, otherwise ignore. *)
extract_opaque : bool;
@@ -383,21 +385,23 @@ type gen_config = {
test_trans_unit_functions : bool;
}
-(** Returns the pair: (has opaque type decls, has opaque fun decls) *)
-let module_has_opaque_decls (ctx : gen_ctx) : bool * bool =
- let has_opaque_types =
- Pure.TypeDeclId.Map.exists
- (fun _ (d : Pure.type_decl) ->
- match d.kind with Opaque -> true | _ -> false)
- ctx.trans_types
- in
- let has_opaque_funs =
- A.FunDeclId.Map.exists
- (fun _ ((_, ((t_fwd, _), _)) : bool * pure_fun_translation) ->
- Option.is_none t_fwd.body)
- ctx.trans_funs
+(** Returns the pair: (has opaque type decls, has opaque fun decls).
+
+ [filter_assumed]: if [true], do not consider as opaque the external definitions
+ that we will map to definitions from the standard library.
+ *)
+let crate_has_opaque_non_builtin_decls (ctx : gen_ctx) (filter_assumed : bool) :
+ bool * bool =
+ let types, funs =
+ LlbcAstUtils.crate_get_opaque_non_builtin_decls ctx.crate filter_assumed
in
- (has_opaque_types, has_opaque_funs)
+ log#ldebug
+ (lazy
+ ("Opaque decls:" ^ "\n- types:\n"
+ ^ String.concat ",\n" (List.map T.show_type_decl types)
+ ^ "\n- functions:\n"
+ ^ String.concat ",\n" (List.map A.show_fun_decl funs)));
+ (types <> [], funs <> [])
(** Export a type declaration.
@@ -423,15 +427,19 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
(true, kind)
in
(* Extract, if the config instructs to do so (depending on whether the type
- * is opaque or not) *)
- if
+ is opaque or not). Remark: we don't check if the definitions are builtin
+ here but in the function [export_types_group]: the reason is that if one
+ definition in the group is builtin, then we must check that all the
+ definitions are marked builtin *)
+ let extract =
(is_opaque && config.extract_opaque)
|| ((not is_opaque) && config.extract_transparent)
- then (
+ in
+ if extract then (
if extract_decl then
- Extract.extract_type_decl ctx.extract_ctx fmt type_decl_group kind def;
+ Extract.extract_type_decl ctx fmt type_decl_group kind def;
if extract_extra_info then
- Extract.extract_type_decl_extra_info ctx.extract_ctx fmt kind def)
+ Extract.extract_type_decl_extra_info ctx fmt kind def)
(** Export a group of types.
@@ -462,41 +470,58 @@ let export_types_group (fmt : Format.formatter) (config : gen_config)
List.map (fun id -> Pure.TypeDeclId.Map.find id ctx.trans_types) ids
in
- (* Extract the type declarations.
-
- Because some declaration groups are delimited, we wrap the declarations
- between [{start,end}_type_decl_group].
+ (* Check if the definition are builtin - if yes they must be ignored.
+ Note that if one definition in the group is builtin, then all the
+ definitions must be builtin *)
+ let builtin =
+ let open ExtractBuiltin in
+ let types_map = builtin_types_map () in
+ List.map
+ (fun (def : Pure.type_decl) ->
+ let sname = name_to_simple_name def.name in
+ SimpleNameMap.find_opt sname types_map <> None)
+ defs
+ in
- Ex.:
- ====
- When targeting HOL4, the calls to [{start,end}_type_decl_group] would generate
- the [Datatype] and [End] delimiters in the snippet of code below:
+ if List.exists (fun b -> b) builtin then
+ (* Sanity check *)
+ assert (List.for_all (fun b -> b) builtin)
+ else (
+ (* Extract the type declarations.
+
+ Because some declaration groups are delimited, we wrap the declarations
+ between [{start,end}_type_decl_group].
+
+ Ex.:
+ ====
+ When targeting HOL4, the calls to [{start,end}_type_decl_group] would generate
+ the [Datatype] and [End] delimiters in the snippet of code below:
+
+ {[
+ Datatype:
+ tree =
+ TLeaf 'a
+ | TNode node ;
+
+ node =
+ Node (tree list)
+ End
+ ]}
+ *)
+ Extract.start_type_decl_group ctx fmt is_rec defs;
+ List.iteri
+ (fun i def ->
+ let kind = kind_from_index i in
+ export_type_decl kind def)
+ defs;
+ Extract.end_type_decl_group fmt is_rec defs;
- {[
- Datatype:
- tree =
- TLeaf 'a
- | TNode node ;
-
- node =
- Node (tree list)
- End
- ]}
- *)
- Extract.start_type_decl_group ctx.extract_ctx fmt is_rec defs;
- List.iteri
- (fun i def ->
- let kind = kind_from_index i in
- export_type_decl kind def)
- defs;
- Extract.end_type_decl_group fmt is_rec defs;
-
- (* Export the extra information (ex.: [Arguments] instructions in Coq) *)
- List.iteri
- (fun i def ->
- let kind = kind_from_index i in
- export_type_extra_info kind def)
- defs
+ (* Export the extra information (ex.: [Arguments] instructions in Coq) *)
+ List.iteri
+ (fun i def ->
+ let kind = kind_from_index i in
+ export_type_extra_info kind def)
+ defs)
(** Export a global declaration.
@@ -504,26 +529,34 @@ let export_types_group (fmt : Format.formatter) (config : gen_config)
*)
let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
(id : A.GlobalDeclId.id) : unit =
- let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in
+ let global_decls = ctx.trans_ctx.global_ctx.global_decls in
let global = A.GlobalDeclId.Map.find id global_decls in
- let _, ((body, loop_fwds), body_backs) =
- A.FunDeclId.Map.find global.body_id ctx.trans_funs
- in
- assert (body_backs = []);
- assert (loop_fwds = []);
+ let trans = A.FunDeclId.Map.find global.body_id ctx.trans_funs in
+ assert (trans.fwd.loops = []);
+ assert (trans.backs = []);
+ let body = trans.fwd.f in
let is_opaque = Option.is_none body.Pure.body in
- if
+ (* Check if we extract the global *)
+ let extract =
config.extract_globals
&& (((not is_opaque) && config.extract_transparent)
|| (is_opaque && config.extract_opaque))
- then
+ in
+ (* Check if it is a builtin global - if yes, we ignore it because we
+ map the definition to one in the standard library *)
+ let open ExtractBuiltin in
+ let sname = name_to_simple_name global.name in
+ let extract =
+ extract && SimpleNameMap.find_opt sname builtin_globals_map = None
+ in
+ if extract then
(* We don't wrap global declaration groups between calls to functions
[{start, end}_global_decl_group] (which don't exist): global declaration
groups are always singletons, so the [extract_global_decl] function
takes care of generating the delimiters.
*)
- Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface
+ Extract.extract_global_decl ctx fmt global body config.interface
(** Utility.
@@ -604,14 +637,13 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config)
then
Some
(fun () ->
- Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause
- def)
+ Extract.extract_fun_decl ctx fmt kind has_decr_clause def)
else None)
decls
in
let extract_defs = List.filter_map (fun x -> x) extract_defs in
if extract_defs <> [] then (
- Extract.start_fun_decl_group ctx.extract_ctx fmt is_rec decls;
+ Extract.start_fun_decl_group ctx fmt is_rec decls;
List.iter (fun f -> f ()) extract_defs;
Extract.end_fun_decl_group fmt is_rec decls)
@@ -621,82 +653,137 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config)
check if the forward and backward functions are mutually recursive.
*)
let export_functions_group (fmt : Format.formatter) (config : gen_config)
- (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit =
- (* Utility to check a function has a decrease clause *)
- let has_decreases_clause (def : Pure.fun_decl) : bool =
- PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
- ctx.functions_with_decreases_clause
+ (ctx : gen_ctx) (pure_ls : pure_fun_translation list) : unit =
+ (* Check if the definition are builtin - if yes they must be ignored.
+ Note that if one definition in the group is builtin, then all the
+ definitions must be builtin *)
+ let builtin =
+ let open ExtractBuiltin in
+ let funs_map = builtin_funs_map () in
+ List.map
+ (fun (trans : pure_fun_translation) ->
+ let sname = name_to_simple_name trans.fwd.f.basename in
+ SimpleNameMap.find_opt sname funs_map <> None)
+ pure_ls
in
- (* Extract the decrease clauses template bodies *)
- if config.extract_template_decreases_clauses then
- List.iter
- (fun (_, ((fwd, loop_fwds), _)) ->
- (* We only generate decreases clauses for the forward functions, because
- the termination argument should only depend on the forward inputs.
- The backward functions thus use the same decreases clauses as the
- forward function.
-
- Rem.: we might filter backward functions in {!PureMicroPasses}, but
- we don't remove forward functions. Instead, we remember if we should
- filter those functions at extraction time with a boolean (see the
- type of the [pure_ls] input parameter).
- *)
- let extract_decrease decl =
- let has_decr_clause = has_decreases_clause decl in
- if has_decr_clause then
- match !Config.backend with
- | Lean ->
- Extract.extract_template_lean_termination_and_decreasing
- ctx.extract_ctx fmt decl
- | FStar ->
- Extract.extract_template_fstar_decreases_clause ctx.extract_ctx
- fmt decl
- | Coq ->
- raise (Failure "Coq doesn't have decreases/termination clauses")
- | HOL4 ->
- raise
- (Failure "HOL4 doesn't have decreases/termination clauses")
- in
- extract_decrease fwd;
- List.iter extract_decrease loop_fwds)
- pure_ls;
-
- (* Concatenate the function definitions, filtering the useless forward
- * functions. *)
- let decls =
- List.concat
- (List.map
- (fun (keep_fwd, ((fwd, fwd_loops), (back_ls : fun_and_loops list))) ->
- let fwd = if keep_fwd then List.append fwd_loops [ fwd ] else [] in
- let back : Pure.fun_decl list =
- List.concat
- (List.map
- (fun (back, loop_backs) -> List.append loop_backs [ back ])
- back_ls)
- in
- List.append fwd back)
- pure_ls)
- in
+ if List.exists (fun b -> b) builtin then
+ (* Sanity check *)
+ assert (List.for_all (fun b -> b) builtin)
+ else
+ (* Utility to check a function has a decrease clause *)
+ let has_decreases_clause (def : Pure.fun_decl) : bool =
+ PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
+ ctx.functions_with_decreases_clause
+ in
- (* Extract the function definitions *)
- (if config.extract_fun_decls then
- (* Group the mutually recursive definitions *)
- let subgroups = ReorderDecls.group_reorder_fun_decls decls in
+ (* Extract the decrease clauses template bodies *)
+ if config.extract_template_decreases_clauses then
+ List.iter
+ (fun { fwd; _ } ->
+ (* We only generate decreases clauses for the forward functions, because
+ the termination argument should only depend on the forward inputs.
+ The backward functions thus use the same decreases clauses as the
+ forward function.
+
+ Rem.: we might filter backward functions in {!PureMicroPasses}, but
+ we don't remove forward functions. Instead, we remember if we should
+ filter those functions at extraction time with a boolean (see the
+ type of the [pure_ls] input parameter).
+ *)
+ let extract_decrease decl =
+ let has_decr_clause = has_decreases_clause decl in
+ if has_decr_clause then
+ match !Config.backend with
+ | Lean ->
+ Extract.extract_template_lean_termination_and_decreasing ctx
+ fmt decl
+ | FStar ->
+ Extract.extract_template_fstar_decreases_clause ctx fmt decl
+ | Coq ->
+ raise
+ (Failure "Coq doesn't have decreases/termination clauses")
+ | HOL4 ->
+ raise
+ (Failure "HOL4 doesn't have decreases/termination clauses")
+ in
+ extract_decrease fwd.f;
+ List.iter extract_decrease fwd.loops)
+ pure_ls;
+
+ (* Concatenate the function definitions, filtering the useless forward
+ * functions. *)
+ let decls =
+ List.concat
+ (List.map
+ (fun { keep_fwd; fwd; backs } ->
+ let fwd =
+ if keep_fwd then List.append fwd.loops [ fwd.f ] else []
+ in
+ let backs : Pure.fun_decl list =
+ List.concat
+ (List.map
+ (fun back -> List.append back.loops [ back.f ])
+ backs)
+ in
+ List.append fwd backs)
+ pure_ls)
+ in
- (* Extract the subgroups *)
- let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit =
- export_functions_group_scc fmt config ctx is_rec decls
- in
- List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups);
-
- (* Insert unit tests if necessary *)
- if config.test_trans_unit_functions then
- List.iter
- (fun (keep_fwd, ((fwd, _), _)) ->
- if keep_fwd then
- Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd)
- pure_ls
+ (* Extract the function definitions *)
+ (if config.extract_fun_decls then
+ (* Group the mutually recursive definitions *)
+ let subgroups = ReorderDecls.group_reorder_fun_decls decls in
+
+ (* Extract the subgroups *)
+ let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit =
+ export_functions_group_scc fmt config ctx is_rec decls
+ in
+ List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups);
+
+ (* Insert unit tests if necessary *)
+ if config.test_trans_unit_functions then
+ List.iter
+ (fun trans ->
+ if trans.keep_fwd then
+ Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f)
+ pure_ls
+
+(** Export a trait declaration. *)
+let export_trait_decl (fmt : Format.formatter) (_config : gen_config)
+ (ctx : gen_ctx) (trait_decl_id : Pure.trait_decl_id) (extract_decl : bool)
+ (extract_extra_info : bool) : unit =
+ let trait_decl = T.TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls in
+ (* Check if the trait declaration is builtin, in which case we ignore it *)
+ let open ExtractBuiltin in
+ let sname = name_to_simple_name trait_decl.name in
+ if SimpleNameMap.find_opt sname (builtin_trait_decls_map ()) = None then (
+ let ctx = { ctx with trait_decl_id = Some trait_decl.def_id } in
+ if extract_decl then Extract.extract_trait_decl ctx fmt trait_decl;
+ if extract_extra_info then
+ Extract.extract_trait_decl_extra_info ctx fmt trait_decl)
+ else ()
+
+(** Export a trait implementation. *)
+let export_trait_impl (fmt : Format.formatter) (_config : gen_config)
+ (ctx : gen_ctx) (trait_impl_id : Pure.trait_impl_id) : unit =
+ (* Lookup the definition *)
+ let trait_impl = T.TraitImplId.Map.find trait_impl_id ctx.trans_trait_impls in
+ let trait_decl =
+ Pure.TraitDeclId.Map.find trait_impl.impl_trait.trait_decl_id
+ ctx.trans_trait_decls
+ in
+ (* Check if the trait implementation is builtin *)
+ let builtin_info =
+ let open ExtractBuiltin in
+ let type_sname = name_to_simple_name trait_impl.name in
+ let trait_sname = name_to_simple_name trait_decl.name in
+ SimpleNamePairMap.find_opt (type_sname, trait_sname)
+ (builtin_trait_impls_map ())
+ in
+ match builtin_info with
+ | None -> Extract.extract_trait_impl ctx fmt trait_impl
+ | Some _ -> ()
(** A generic utility to generate the extracted definitions: as we may want to
split the definitions between different files (or not), we can control
@@ -712,12 +799,19 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
let export_functions_group = export_functions_group fmt config ctx in
let export_global = export_global fmt config ctx in
let export_types_group = export_types_group fmt config ctx in
+ let export_trait_decl_group id =
+ export_trait_decl fmt config ctx id true false
+ in
+ let export_trait_decl_group_extra_info id =
+ export_trait_decl fmt config ctx id false true
+ in
+ let export_trait_impl = export_trait_impl fmt config ctx in
let export_state_type () : unit =
let kind =
if config.interface then ExtractBase.Declared else ExtractBase.Assumed
in
- Extract.extract_state_type fmt ctx.extract_ctx kind
+ Extract.extract_state_type fmt ctx kind
in
let export_decl_group (dg : A.declaration_group) : unit =
@@ -725,11 +819,18 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
| Type (NonRec id) ->
if config.extract_types then export_types_group false [ id ]
| Type (Rec ids) -> if config.extract_types then export_types_group true ids
- | Fun (NonRec id) ->
+ | Fun (NonRec id) -> (
(* Lookup *)
let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in
- (* Translate *)
- export_functions_group [ pure_fun ]
+ (* Special case: we skip trait method *declarations* (we will
+ extract their type directly in the records we generate for
+ the trait declarations themselves, there is no point in having
+ separate type definitions) *)
+ match pure_fun.fwd.f.Pure.kind with
+ | TraitMethodDecl _ -> ()
+ | _ ->
+ (* Translate *)
+ export_functions_group [ pure_fun ])
| Fun (Rec ids) ->
(* General case of mutually recursive functions *)
(* Lookup *)
@@ -739,11 +840,19 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
(* Translate *)
export_functions_group pure_funs
| Global id -> export_global id
+ | TraitDecl id ->
+ (* TODO: update to extract groups *)
+ if config.extract_trait_decls && config.extract_transparent then (
+ export_trait_decl_group id;
+ export_trait_decl_group_extra_info id)
+ | TraitImpl id ->
+ if config.extract_trait_impls && config.extract_transparent then
+ export_trait_impl id
in
(* If we need to export the state type: we try to export it after we defined
* the type definitions, because if the user wants to define a model for the
- * type, he might want to reuse those in the state type.
+ * type, they might want to reuse those in the state type.
* More specifically: if we extract functions in the same file as the type,
* we have no choice but to define the state type before the functions,
* because they may reuse this state type: in this case, we define/declare
@@ -752,37 +861,10 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
if config.extract_state_type && config.extract_fun_decls then
export_state_type ();
- (* Obsolete: (TODO: remove) For Lean we parameterize the entire development by a section
- variable called opaque_defs, of type OpaqueDefs. The code below emits the type
- definition for OpaqueDefs, which is a structure, in which each field is one of the
- functions marked as Opaque. We emit the `structure ...` bit here, then rely on
- `extract_fun_decl` to be aware of this, and skip the keyword (e.g. "axiom" or "val")
- so as to generate valid syntax for records.
-
- We also generate such a structure only if there actually are opaque definitions. *)
- let wrap_in_sig =
- config.extract_opaque && config.extract_fun_decls
- && !Config.wrap_opaque_in_sig
- &&
- let _, opaque_funs = module_has_opaque_decls ctx in
- opaque_funs
- in
- if wrap_in_sig then (
- (* We change the name of the structure depending on whether we *only*
- extract opaque definitions, or if we extract all definitions *)
- let struct_name =
- if config.extract_transparent then "Definitions" else "OpaqueDefs"
- in
- Format.pp_print_break fmt 0 0;
- Format.pp_open_vbox fmt ctx.extract_ctx.indent_incr;
- Format.pp_print_string fmt ("structure " ^ struct_name ^ " where");
- Format.pp_print_break fmt 0 0);
List.iter export_decl_group ctx.crate.declarations;
if config.extract_state_type && not config.extract_fun_decls then
- export_state_type ();
-
- if wrap_in_sig then Format.pp_close_box fmt ()
+ export_state_type ()
type extract_file_info = {
filename : string;
@@ -904,7 +986,9 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (fi : extract_file_info)
let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
unit =
(* Translate the module to the pure AST *)
- let trans_ctx, trans_types, trans_funs = translate_crate_to_pure crate in
+ let trans_ctx, trans_types, trans_funs, trans_trait_decls, trans_trait_impls =
+ translate_crate_to_pure crate
+ in
(* Initialize the extraction context - for now we extract only to F*.
* We initialize the names map by registering the keywords used in the
@@ -916,41 +1000,27 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
in
(* Initialize the names map (we insert the names of the "primitives"
declarations, and insert the names of the local declarations later) *)
- let mk_formatter_and_names_map = Extract.mk_formatter_and_names_map in
- let fmt, names_map =
- mk_formatter_and_names_map trans_ctx crate.name
+ let fmt, names_maps =
+ Extract.mk_formatter_and_names_maps trans_ctx crate.name
variant_concatenate_type_name
in
- (* Put everything in the context *)
- let ctx =
- {
- ExtractBase.trans_ctx;
- names_map;
- unsafe_names_map = { id_to_name = ExtractBase.IdMap.empty };
- fmt;
- indent_incr = 2;
- use_opaque_pre = !Config.split_files;
- use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses;
- fun_name_info = PureUtils.RegularFunIdMap.empty;
- }
- in
(* We need to compute which functions are recursive, in order to know
* whether we should generate a decrease clause or not. *)
let rec_functions =
List.map
- (fun (_, ((fwd, loop_fwds), _)) ->
- let fwd =
- if fwd.Pure.signature.info.effect_info.is_rec then
- [ (fwd.def_id, None) ]
+ (fun { fwd; _ } ->
+ let fwd_f =
+ if fwd.f.Pure.signature.info.effect_info.is_rec then
+ [ (fwd.f.def_id, None) ]
else []
in
let loop_fwds =
List.map
(fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ])
- loop_fwds
+ fwd.loops
in
- fwd :: loop_fwds)
+ fwd_f :: loop_fwds)
trans_funs
in
let rec_functions : PureUtils.fun_loop_id list =
@@ -958,22 +1028,70 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
in
let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in
- (* Register unique names for all the top-level types, globals and functions.
+ (* Put the translated definitions in maps *)
+ let trans_types =
+ Pure.TypeDeclId.Map.of_list
+ (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types)
+ in
+ let trans_funs : pure_fun_translation A.FunDeclId.Map.t =
+ A.FunDeclId.Map.of_list
+ (List.map
+ (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans))
+ trans_funs)
+ in
+
+ (* Put everything in the context *)
+ let ctx =
+ let trans_trait_decls =
+ T.TraitDeclId.Map.of_list
+ (List.map
+ (fun (d : Pure.trait_decl) -> (d.def_id, d))
+ trans_trait_decls)
+ in
+ let trans_trait_impls =
+ T.TraitImplId.Map.of_list
+ (List.map
+ (fun (d : Pure.trait_impl) -> (d.def_id, d))
+ trans_trait_impls)
+ in
+ {
+ ExtractBase.crate;
+ trans_ctx;
+ names_maps;
+ fmt;
+ indent_incr = 2;
+ use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses;
+ fun_name_info = PureUtils.RegularFunIdMap.empty;
+ trait_decl_id = None (* None by default *);
+ is_provided_method = false (* false by default *);
+ trans_trait_decls;
+ trans_trait_impls;
+ trans_types;
+ trans_funs;
+ functions_with_decreases_clause = rec_functions;
+ types_filter_type_args_map = Pure.TypeDeclId.Map.empty;
+ funs_filter_type_args_map = Pure.FunDeclId.Map.empty;
+ trait_impls_filter_type_args_map = Pure.TraitImplId.Map.empty;
+ }
+ in
+
+ (* Register unique names for all the top-level types, globals, functions...
* Note that the order in which we generate the names doesn't matter:
* we just need to generate a mapping from identifier to name, and make
* sure there are no name clashes. *)
let ctx =
List.fold_left
(fun ctx def -> Extract.extract_type_decl_register_names ctx def)
- ctx trans_types
+ ctx
+ (Pure.TypeDeclId.Map.values trans_types)
in
let ctx =
List.fold_left
- (fun ctx (keep_fwd, defs) ->
+ (fun ctx (trans : pure_fun_translation) ->
(* If requested by the user, register termination measures and decreases
proofs for all the recursive functions *)
- let fwd_def = fst (fst defs) in
+ let fwd_def = trans.fwd.f in
let gen_decr_clause (def : Pure.fun_decl) =
!Config.extract_decreases_clauses
&& PureUtils.FunLoopIdSet.mem
@@ -984,10 +1102,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
* those are handled later *)
let is_global = fwd_def.Pure.is_global_decl_body in
if is_global then ctx
- else
- Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause
- defs)
- ctx trans_funs
+ else Extract.extract_fun_decl_register_names ctx gen_decr_clause trans)
+ ctx
+ (A.FunDeclId.Map.values trans_funs)
in
let ctx =
@@ -995,6 +1112,16 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
(A.GlobalDeclId.Map.values crate.globals)
in
+ let ctx =
+ List.fold_left Extract.extract_trait_decl_register_names ctx
+ trans_trait_decls
+ in
+
+ let ctx =
+ List.fold_left Extract.extract_trait_impl_register_names ctx
+ trans_trait_impls
+ in
+
(* Open the output file *)
(* First compute the filename by replacing the extension and converting the
* case (rust module names are snake case) *)
@@ -1023,19 +1150,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
(namespace, crate_name, Filename.concat dest_dir crate_name)
in
- (* Put the translated definitions in maps *)
- let trans_types =
- Pure.TypeDeclId.Map.of_list
- (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types)
- in
- let trans_funs =
- A.FunDeclId.Map.of_list
- (List.map
- (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) ->
- ((fst fd).def_id, (keep_fwd, (fd, bdl))))
- trans_funs)
- in
-
let mkdir_if dest_dir =
if not (Sys.file_exists dest_dir) then (
log#linfo (lazy ("Creating missing directory: " ^ dest_dir));
@@ -1091,16 +1205,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
in
(* Extract the file(s) *)
- let gen_ctx =
- {
- crate;
- extract_ctx = ctx;
- trans_types;
- trans_funs;
- functions_with_decreases_clause = rec_functions;
- }
- in
-
let module_delimiter =
match !Config.backend with
| FStar -> "."
@@ -1136,6 +1240,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
extract_decreases_clauses = !Config.extract_decreases_clauses;
extract_template_decreases_clauses = false;
extract_fun_decls = false;
+ extract_trait_decls = false;
+ extract_trait_impls = false;
extract_transparent = true;
extract_opaque = false;
extract_state_type = false;
@@ -1147,7 +1253,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
(* Check if there are opaque types and functions - in which case we need
* to split *)
- let has_opaque_types, has_opaque_funs = module_has_opaque_decls gen_ctx in
+ let has_opaque_types, has_opaque_funs =
+ crate_has_opaque_non_builtin_decls ctx true
+ in
let has_opaque_types = has_opaque_types || !Config.use_state in
(* Extract the types *)
@@ -1168,6 +1276,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
{
base_gen_config with
extract_types = true;
+ extract_trait_decls = true;
extract_opaque = true;
extract_state_type = !Config.use_state;
interface = has_opaque_types;
@@ -1186,7 +1295,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [];
}
in
- extract_file types_config gen_ctx file_info;
+ extract_file types_config ctx file_info;
(* Extract the template clauses *)
(if needs_clauses_module && !Config.extract_template_decreases_clauses then
@@ -1214,9 +1323,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [];
}
in
- extract_file template_clauses_config gen_ctx file_info);
+ extract_file template_clauses_config ctx file_info);
- (* Extract the opaque functions, if needed *)
+ (* Extract the opaque declarations, if needed *)
let opaque_funs_module =
if has_opaque_funs then (
(* In the case of Lean we generate a template file *)
@@ -1244,17 +1353,13 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
{
base_gen_config with
extract_fun_decls = true;
+ extract_trait_impls = true;
+ extract_globals = true;
extract_transparent = false;
extract_opaque = true;
interface = true;
}
in
- let gen_ctx =
- {
- gen_ctx with
- extract_ctx = { gen_ctx.extract_ctx with use_opaque_pre = false };
- }
- in
let file_info =
{
filename = opaque_filename;
@@ -1268,7 +1373,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [ types_module ];
}
in
- extract_file opaque_config gen_ctx file_info;
+ extract_file opaque_config ctx file_info;
(* Return the additional dependencies *)
[ opaque_imported_module ])
else []
@@ -1281,6 +1386,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
{
base_gen_config with
extract_fun_decls = true;
+ extract_trait_impls = true;
extract_globals = true;
test_trans_unit_functions = !Config.test_trans_unit_functions;
}
@@ -1307,7 +1413,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
[ types_module ] @ opaque_funs_module @ clauses_module;
}
in
- extract_file fun_config gen_ctx file_info)
+ extract_file fun_config ctx file_info)
else
let gen_config =
{
@@ -1316,6 +1422,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
extract_template_decreases_clauses =
!Config.extract_template_decreases_clauses;
extract_fun_decls = true;
+ extract_trait_decls = true;
+ extract_trait_impls = true;
extract_transparent = true;
extract_opaque = true;
extract_state_type = !Config.use_state;
@@ -1337,7 +1445,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [];
}
in
- extract_file gen_config gen_ctx file_info);
+ extract_file gen_config ctx file_info);
(* Generate the build file *)
match !Config.backend with
diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml
index ba5e237b..3427fd43 100644
--- a/compiler/TranslateCore.ml
+++ b/compiler/TranslateCore.ml
@@ -10,64 +10,69 @@ module FA = FunsAnalysis
(** The local logger *)
let log = L.translate_log
-type type_context = C.type_context [@@deriving show]
-
-type fun_context = {
- fun_decls : A.fun_decl A.FunDeclId.Map.t;
- fun_infos : FA.fun_info A.FunDeclId.Map.t;
-}
-[@@deriving show]
+type trans_ctx = C.decls_ctx [@@deriving show]
+type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list }
+type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list
-type global_context = C.global_context [@@deriving show]
+type pure_fun_translation = {
+ keep_fwd : bool;
+ (** Should we extract the forward function?
-type trans_ctx = {
- type_context : type_context;
- fun_context : fun_context;
- global_context : global_context;
+ If the forward function returns `()` and there is exactly one
+ backward function, we may merge the forward into the backward
+ function and thus don't extract the forward function)?
+ *)
+ fwd : fun_and_loops;
+ backs : fun_and_loops list;
}
-type fun_and_loops = Pure.fun_decl * Pure.fun_decl list
-type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list
-type pure_fun_translation = fun_and_loops * fun_and_loops list
+let trans_ctx_to_type_formatter (ctx : trans_ctx)
+ (type_params : Pure.type_var list)
+ (const_generic_params : Pure.const_generic_var list) :
+ PrintPure.type_formatter =
+ let type_decls = ctx.type_ctx.type_decls in
+ let global_decls = ctx.global_ctx.global_decls in
+ let trait_decls = ctx.trait_decls_ctx.trait_decls in
+ let trait_impls = ctx.trait_impls_ctx.trait_impls in
+ PrintPure.mk_type_formatter type_decls global_decls trait_decls trait_impls
+ type_params const_generic_params
let type_decl_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string =
- let type_params = def.type_params in
- let cg_params = def.const_generic_params in
- let type_decls = ctx.type_context.type_decls in
- let global_decls = ctx.global_context.global_decls in
+ let generics = def.generics in
let fmt =
- PrintPure.mk_type_formatter type_decls global_decls type_params cg_params
+ trans_ctx_to_type_formatter ctx generics.types generics.const_generics
in
PrintPure.type_decl_to_string fmt def
let type_id_to_string (ctx : trans_ctx) (id : Pure.TypeDeclId.id) : string =
Print.fun_name_to_string
- (Pure.TypeDeclId.Map.find id ctx.type_context.type_decls).name
+ (Pure.TypeDeclId.Map.find id ctx.type_ctx.type_decls).name
+
+let trans_ctx_to_ast_formatter (ctx : trans_ctx)
+ (type_params : Pure.type_var list)
+ (const_generic_params : Pure.const_generic_var list) :
+ PrintPure.ast_formatter =
+ let type_decls = ctx.type_ctx.type_decls in
+ let fun_decls = ctx.fun_ctx.fun_decls in
+ let global_decls = ctx.global_ctx.global_decls in
+ let trait_decls = ctx.trait_decls_ctx.trait_decls in
+ let trait_impls = ctx.trait_impls_ctx.trait_impls in
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls trait_decls
+ trait_impls type_params const_generic_params
let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string =
- let type_params = sg.type_params in
- let cg_params = sg.const_generic_params in
- let type_decls = ctx.type_context.type_decls in
- let fun_decls = ctx.fun_context.fun_decls in
- let global_decls = ctx.global_context.global_decls in
+ let generics = sg.generics in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ trans_ctx_to_ast_formatter ctx generics.types generics.const_generics
in
PrintPure.fun_sig_to_string fmt sg
let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string =
- let type_params = def.signature.type_params in
- let cg_params = def.signature.const_generic_params in
- let type_decls = ctx.type_context.type_decls in
- let fun_decls = ctx.fun_context.fun_decls in
- let global_decls = ctx.global_context.global_decls in
+ let generics = def.signature.generics in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ trans_ctx_to_ast_formatter ctx generics.types generics.const_generics
in
PrintPure.fun_decl_to_string fmt def
let fun_decl_id_to_string (ctx : trans_ctx) (id : A.FunDeclId.id) : string =
- Print.fun_name_to_string
- (A.FunDeclId.Map.find id ctx.fun_context.fun_decls).name
+ Print.fun_name_to_string (A.FunDeclId.Map.find id ctx.fun_ctx.fun_decls).name
diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml
index 925f6d39..38d350b1 100644
--- a/compiler/TypesAnalysis.ml
+++ b/compiler/TypesAnalysis.ml
@@ -14,11 +14,10 @@ type expl_info = subtype_info [@@deriving show]
type type_borrows_info = {
contains_static : bool;
- (** Does the type (transitively) contains a static borrow? *)
- contains_borrow : bool;
- (** Does the type (transitively) contains a borrow? *)
+ (** Does the type (transitively) contain a static borrow? *)
+ contains_borrow : bool; (** Does the type (transitively) contain a borrow? *)
contains_nested_borrows : bool;
- (** Does the type (transitively) contains nested borrows? *)
+ (** Does the type (transitively) contain nested borrows? *)
contains_borrow_under_mut : bool;
}
[@@deriving show]
@@ -61,7 +60,7 @@ let initialize_g_type_info (param_infos : 'p) : 'p g_type_info =
let initialize_type_decl_info (def : type_decl) : type_decl_info =
let param_info = { under_borrow = false; under_mut_borrow = false } in
- let param_infos = List.map (fun _ -> param_info) def.type_params in
+ let param_infos = List.map (fun _ -> param_info) def.generics.types in
initialize_g_type_info param_infos
let type_decl_info_to_partial_type_info (info : type_decl_info) :
@@ -122,7 +121,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
let rec analyze (expl_info : expl_info) (ty_info : partial_type_info)
(ty : 'r ty) : partial_type_info =
match ty with
- | Literal _ | Never -> ty_info
+ | Literal _ | Never | TraitType _ -> ty_info
| TypeVar var_id -> (
(* Update the information for the proper parameter, if necessary *)
match ty_info.param_infos with
@@ -169,22 +168,21 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
in
(* Continue exploring *)
analyze expl_info ty_info rty
- | Adt
- ( (Tuple | Assumed (Box | Vec | Option | Slice | Array | Str | Range)),
- _,
- tys,
- _ ) ->
+ | RawPtr (rty, _) ->
+ (* TODO: not sure what to do here *)
+ analyze expl_info ty_info rty
+ | Adt ((Tuple | Assumed (Box | Slice | Array | Str)), generics) ->
(* Nothing to update: just explore the type parameters *)
List.fold_left
(fun ty_info ty -> analyze expl_info ty_info ty)
- ty_info tys
- | Adt (AdtId adt_id, regions, tys, _cgs) ->
+ ty_info generics.types
+ | Adt (AdtId adt_id, generics) ->
(* Lookup the information for this type definition *)
let adt_info = TypeDeclId.Map.find adt_id infos in
(* Update the type info with the information from the adt *)
let ty_info = update_ty_info ty_info adt_info.borrows_info in
(* Check if 'static appears in the region parameters *)
- let found_static = List.exists r_is_static regions in
+ let found_static = List.exists r_is_static generics.regions in
let borrows_info = ty_info.borrows_info in
let borrows_info =
{
@@ -196,7 +194,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
let ty_info = { ty_info with borrows_info } in
(* For every instantiated type parameter: update the exploration info
* then explore the type *)
- let params_tys = List.combine adt_info.param_infos tys in
+ let params_tys = List.combine adt_info.param_infos generics.types in
let ty_info =
List.fold_left
(fun ty_info (param_info, ty) ->
@@ -235,6 +233,14 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
in
(* Return *)
ty_info
+ | Arrow (inputs, output) ->
+ (* Just dive into the arrow *)
+ let ty_info =
+ List.fold_left
+ (fun ty_info ty -> analyze expl_info ty_info ty)
+ ty_info inputs
+ in
+ analyze expl_info ty_info output
in
(* Explore *)
analyze expl_info_init ty_info ty
diff --git a/compiler/Values.ml b/compiler/Values.ml
index d884c319..de27e7a9 100644
--- a/compiler/Values.ml
+++ b/compiler/Values.ml
@@ -52,6 +52,10 @@ type sv_kind =
(** The result of a loop join (when computing loop fixed points) *)
| Aggregate
(** A symbolic value we introduce in place of an aggregate value *)
+ | ConstGeneric
+ (** A symbolic value we introduce when using a const generic as a value *)
+ | TraitConst
+ (** A symbolic value we introduce when evaluating a trait associated constant *)
[@@deriving show, ord]
(** Ancestor for {!symbolic_value} iter visitor *)
diff --git a/compiler/dune b/compiler/dune
index 6785cad4..648c7325 100644
--- a/compiler/dune
+++ b/compiler/dune
@@ -12,6 +12,7 @@
(pps ppx_deriving.show ppx_deriving.ord visitors.ppx))
(libraries charon core_unix unionFind ocamlgraph)
(modules
+ AssociatedTypes
Assumed
Collections
Config
@@ -22,6 +23,8 @@
ExpressionsUtils
Extract
ExtractBase
+ ExtractBuiltin
+ ExtractTypes
FunsAnalysis
Identifiers
InterpreterBorrowsCore
@@ -90,4 +93,4 @@
-g
;-dsource
-warn-error
- -5-8-9-11-14-33-20-21-26-27-39)))
+ -5@8-9-11-14-33-20-21-26-27-39)))