diff options
author | Son HO | 2023-11-10 18:21:06 +0100 |
---|---|---|
committer | GitHub | 2023-11-10 18:21:06 +0100 |
commit | 587f1ebc0178acb19029d3fc9a729c197082aba7 (patch) | |
tree | f29805e5426f9f3fabe12d3fdadda96a1e987880 /compiler | |
parent | 7fc7c82aa61d782b335e7cf37231fd9998cd0d89 (diff) | |
parent | d300be95c28ff3147bb6f6a65992df5b9b571bdf (diff) |
Merge pull request #44 from AeneasVerif/son_traits_types
Add support for traits
Diffstat (limited to '')
44 files changed, 8964 insertions, 4758 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml new file mode 100644 index 00000000..581e218c --- /dev/null +++ b/compiler/AssociatedTypes.ml @@ -0,0 +1,681 @@ +(** This file implements utilities to handle trait associated types, in + particular with normalization helpers. + + When normalizing a type, we simplify the references to the trait associated + types, and choose a representative when there are equalities between types + enforced by local clauses (i.e., clauses of the shape [where Trait1::T = Trait2::U]). + *) + +module T = Types +module TU = TypesUtils +module V = Values +module E = Expressions +module A = LlbcAst +module C = Contexts +module Subst = Substitute +module L = Logging +module UF = UnionFind +module PA = Print.EvalCtxLlbcAst + +(** The local logger *) +let log = L.associated_types_log + +let trait_type_ref_substitute (subst : ('r, 'r1) Subst.subst) + (r : 'r C.trait_type_ref) : 'r1 C.trait_type_ref = + let { C.trait_ref; type_name } = r in + let trait_ref = Subst.trait_ref_substitute subst trait_ref in + { C.trait_ref; type_name } + +(* TODO: how not to duplicate below? *) +module RTyOrd = struct + type t = T.rty + + let compare = T.compare_rty + let to_string = T.show_rty + let pp_t = T.pp_rty + let show_t = T.show_rty +end + +module STyOrd = struct + type t = T.sty + + let compare = T.compare_sty + let to_string = T.show_sty + let pp_t = T.pp_sty + let show_t = T.show_sty +end + +module RTyMap = Collections.MakeMap (RTyOrd) +module STyMap = Collections.MakeMap (STyOrd) + +(* TODO: is it possible not to have this? *) +module type TypeWrapper = sig + type t +end + +(* TODO: don't manage to get the syntax right so using a functor *) +module MakeNormalizer + (R : TypeWrapper) + (RTyMap : Collections.Map with type key = R.t T.region T.ty) + (M : Collections.Map with type key = R.t T.region C.trait_type_ref) = +struct + let compute_norm_trait_types_from_preds + (trait_type_constraints : R.t T.region T.trait_type_constraint list) : + R.t T.region T.ty M.t = + (* Compute a union-find structure by recursively exploring the predicates and clauses *) + let norm : R.t T.region T.ty UF.elem RTyMap.t ref = ref RTyMap.empty in + let get_ref (ty : R.t T.region T.ty) : R.t T.region T.ty UF.elem = + match RTyMap.find_opt ty !norm with + | Some r -> r + | None -> + let r = UF.make ty in + norm := RTyMap.add ty r !norm; + r + in + let add_trait_type_constraint (c : R.t T.region T.trait_type_constraint) = + let trait_ty = T.TraitType (c.trait_ref, c.generics, c.type_name) in + let trait_ty_ref = get_ref trait_ty in + let ty_ref = get_ref c.ty in + let new_repr = UF.get ty_ref in + let merged = UF.union trait_ty_ref ty_ref in + (* Not sure the set operation is necessary, but I want to control which + representative is chosen *) + UF.set merged new_repr + in + (* Explore the local predicates *) + List.iter add_trait_type_constraint trait_type_constraints; + (* TODO: explore the local clauses *) + (* Compute the norm maps *) + let rbindings = + List.map (fun (k, v) -> (k, UF.get v)) (RTyMap.bindings !norm) + in + (* Filter the keys to keep only the trait type aliases *) + let rbindings = + List.filter_map + (fun (k, v) -> + match k with + | T.TraitType (trait_ref, generics, type_name) -> + assert (generics = TypesUtils.mk_empty_generic_args); + Some ({ C.trait_ref; type_name }, v) + | _ -> None) + rbindings + in + M.of_list rbindings +end + +(** Compute the representative classes of trait associated types, for normalization *) +let compute_norm_trait_stypes_from_preds + (trait_type_constraints : T.strait_type_constraint list) : + T.sty C.STraitTypeRefMap.t = + (* Compute the normalization map for the types with regions *) + let module R = struct + type t = T.region_var_id + end in + let module M = C.STraitTypeRefMap in + let module Norm = MakeNormalizer (R) (STyMap) (M) in + Norm.compute_norm_trait_types_from_preds trait_type_constraints + +(** Compute the representative classes of trait associated types, for normalization *) +let compute_norm_trait_types_from_preds + (trait_type_constraints : T.rtrait_type_constraint list) : + T.ety C.ETraitTypeRefMap.t * T.rty C.RTraitTypeRefMap.t = + (* Compute the normalization map for the types with regions *) + let module R = struct + type t = T.region_id + end in + let module M = C.RTraitTypeRefMap in + let module Norm = MakeNormalizer (R) (RTyMap) (M) in + let rbindings = + Norm.compute_norm_trait_types_from_preds trait_type_constraints + in + (* Compute the normalization map for the types with erased regions *) + let ebindings = + List.map + (fun (k, v) -> + ( trait_type_ref_substitute Subst.erase_regions_subst k, + Subst.erase_regions v )) + (M.bindings rbindings) + in + (C.ETraitTypeRefMap.of_list ebindings, rbindings) + +let ctx_add_norm_trait_stypes_from_preds (ctx : C.eval_ctx) + (trait_type_constraints : T.strait_type_constraint list) : C.eval_ctx = + let norm_trait_stypes = + compute_norm_trait_stypes_from_preds trait_type_constraints + in + { ctx with C.norm_trait_stypes } + +let ctx_add_norm_trait_types_from_preds (ctx : C.eval_ctx) + (trait_type_constraints : T.rtrait_type_constraint list) : C.eval_ctx = + let norm_trait_etypes, norm_trait_rtypes = + compute_norm_trait_types_from_preds trait_type_constraints + in + { ctx with C.norm_trait_etypes; norm_trait_rtypes } + +(** A trait instance id refers to a local clause if it only uses the variants: + [Self], [Clause], [ParentClause], [ItemClause] *) +let rec trait_instance_id_is_local_clause (id : 'r T.trait_instance_id) : bool = + match id with + | T.Self | Clause _ -> true + | TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ | FnPointer _ -> + false + | ParentClause (id, _, _) | ItemClause (id, _, _, _) -> + trait_instance_id_is_local_clause id + +(** About the conversion functions: for now we need them (TODO: merge ety, rty, etc.), + but they should be applied to types without regions. + *) +type 'r norm_ctx = { + ctx : C.eval_ctx; + get_ty_repr : 'r C.trait_type_ref -> 'r T.ty option; + convert_ety : T.ety -> 'r T.ty; (* TODO: remove? *) + convert_etrait_ref : T.etrait_ref -> 'r T.trait_ref; (* TODO: remove? *) + ty_to_string : 'r T.ty -> string; + generic_params_to_string : T.generic_params -> string; + generic_args_to_string : 'r T.generic_args -> string; + trait_ref_to_string : 'r T.trait_ref -> string; + trait_instance_id_to_string : 'r T.trait_instance_id -> string; + pp_r : Format.formatter -> 'r -> unit; +} + +(** Small utility to lookup trait impls, together with a substitution. + + Remark: one reason we have those small helpers is that all functions are + parameterized by a type variable 'r. The OCaml type inferencer and type + checker are however not very good at generating precise error messages in + this context: if in the body of the function we have an overly constrained + usage of 'r (for instance, the type inferencer deduces 'r should be + [T.erased_region]), it will not be able to pinpoint the location which + introduced the constraints and we just get a type-checking error for the + whole function. The fact that we have mutually recursive functions makes it + worse (the type-checker sometimes indicates a well-typed function as not + well-typed, because it calls a not well-typed function...). + By isolating the places where such errors typically happen in small helpers + (i.e., the places where we convert between different types of regions by + performing substitutions), we make maintenance a lot easier. + *) +let ctx_lookup_trait_impl : + 'r. + 'r norm_ctx -> + T.TraitImplId.id -> + 'r T.generic_args -> + A.trait_impl * (T.region_var_id T.region, 'r) Subst.subst = + fun ctx impl_id generics -> + (* Lookup the implementation *) + let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in + (* The substitution *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = + Subst.make_subst_from_generics_no_regions trait_impl.generics generics + tr_self + in + (* Return *) + (trait_impl, subst) + +let ctx_lookup_trait_impl_ty : + 'r. + 'r norm_ctx -> T.TraitImplId.id -> 'r T.generic_args -> string -> 'r T.ty + = + fun ctx impl_id generics type_name -> + (* Lookup the implementation *) + let trait_impl, subst = ctx_lookup_trait_impl ctx impl_id generics in + (* Lookup the type *) + let ty = snd (List.assoc type_name trait_impl.types) in + (* Annoying: convert etype to an stype - TODO: how to avoid that? *) + let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in + (* Substitute *) + Subst.ty_substitute subst ty + +let ctx_lookup_trait_impl_parent_clause : + 'r. + 'r norm_ctx -> + T.TraitImplId.id -> + 'r T.generic_args -> + T.TraitClauseId.id -> + 'r T.trait_ref = + fun ctx impl_id generics clause_id -> + (* Lookup the implementation *) + let trait_impl, subst = ctx_lookup_trait_impl ctx impl_id generics in + (* Lookup the clause *) + let clause = T.TraitClauseId.nth trait_impl.parent_trait_refs clause_id in + (* Sanity check: the clause necessarily refers to an impl *) + let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in + (* Substitute *) + Subst.trait_ref_substitute subst clause + +let ctx_lookup_trait_impl_item_clause : + 'r. + 'r norm_ctx -> + T.TraitImplId.id -> + 'r T.generic_args -> + string -> + T.TraitClauseId.id -> + 'r T.trait_ref = + fun ctx impl_id generics item_name clause_id -> + (* Lookup the implementation *) + let trait_impl, subst = ctx_lookup_trait_impl ctx impl_id generics in + (* Lookup the item then its clause *) + let item = List.assoc item_name trait_impl.types in + let clause = T.TraitClauseId.nth (fst item) clause_id in + (* Sanity check: the clause necessarily refers to an impl *) + let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in + (* Annoying: convert etype to an stype - TODO: how to avoid that? *) + let clause : T.strait_ref = + TypesUtils.etrait_ref_no_regions_to_gr_trait_ref clause + in + (* Substitute *) + Subst.trait_ref_substitute subst clause + +(** Normalize a type by simplifying the references to trait associated types + and choosing a representative when there are equalities between types + enforced by local clauses (i.e., `where Trait1::T = Trait2::U`. + + See the comments for {!ctx_normalize_trait_instance_id}. + *) +let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty = + fun ctx ty -> + log#ldebug (lazy ("ctx_normalize_ty: " ^ ctx.ty_to_string ty)); + match ty with + | T.Adt (id, generics) -> Adt (id, ctx_normalize_generic_args ctx generics) + | TypeVar _ | Literal _ | Never -> ty + | Ref (r, ty, rkind) -> + let ty = ctx_normalize_ty ctx ty in + T.Ref (r, ty, rkind) + | RawPtr (ty, rkind) -> + let ty = ctx_normalize_ty ctx ty in + RawPtr (ty, rkind) + | Arrow (inputs, output) -> + let inputs = List.map (ctx_normalize_ty ctx) inputs in + let output = ctx_normalize_ty ctx output in + Arrow (inputs, output) + | TraitType (trait_ref, generics, type_name) -> ( + log#ldebug + (lazy + ("ctx_normalize_ty:\n- trait type: " ^ ctx.ty_to_string ty + ^ "\n- trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref:\n" + ^ T.show_trait_ref ctx.pp_r trait_ref + ^ "\n- generics:\n" + ^ ctx.generic_args_to_string generics)); + (* Normalize and attempt to project the type from the trait ref *) + let trait_ref = ctx_normalize_trait_ref ctx trait_ref in + let generics = ctx_normalize_generic_args ctx generics in + (* For now, we don't support higher order types *) + assert (generics = TypesUtils.mk_empty_generic_args); + let ty : 'r T.ty = + match trait_ref.trait_id with + | T.TraitRef + { T.trait_id = T.TraitImpl impl_id; generics = ref_generics; _ } -> + assert (ref_generics = TypesUtils.mk_empty_generic_args); + log#ldebug + (lazy + ("ctx_normalize_ty: trait type: trait ref: " + ^ ctx.ty_to_string ty)); + (* Lookup the type *) + let ty = + ctx_lookup_trait_impl_ty ctx impl_id trait_ref.generics type_name + in + (* Normalize *) + ctx_normalize_ty ctx ty + | T.TraitImpl impl_id -> + log#ldebug + (lazy + ("ctx_normalize_ty (trait impl):\n- trait type: " + ^ ctx.ty_to_string ty ^ "\n- trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref:\n" + ^ T.show_trait_ref ctx.pp_r trait_ref)); + (* This happens. This doesn't come from the substitutions + performed by Aeneas (the [TraitImpl] would be wrapped in a + [TraitRef] but from non-normalized traits translated from + the Rustc AST. + TODO: factor out with the branch above. + *) + (* Lookup the type *) + let ty = + ctx_lookup_trait_impl_ty ctx impl_id trait_ref.generics type_name + in + (* Normalize *) + ctx_normalize_ty ctx ty + | _ -> + log#ldebug + (lazy + ("ctx_normalize_ty: trait type: not a trait ref: " + ^ ctx.ty_to_string ty ^ "\n- trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref:\n" + ^ T.show_trait_ref ctx.pp_r trait_ref)); + (* We can't project *) + assert (trait_instance_id_is_local_clause trait_ref.trait_id); + T.TraitType (trait_ref, generics, type_name) + in + let tr : 'r C.trait_type_ref = { C.trait_ref; type_name } in + (* Lookup the representative, if there is *) + match ctx.get_ty_repr tr with None -> ty | Some ty -> ty) + +(** This returns the normalized trait instance id together with an optional + reference to a trait **implementation** (the `trait_ref` we return has + necessarily for instance id a [TraitImpl]). + + We need this in particular to simplify the trait instance ids after we + performed a substitution. + + Example: + ======== + {[ + trait Trait { + type S + } + + impl TraitImpl for Foo { + type S = usize + } + + fn f<T : Trait>(...) -> T::S; + + ... + let x = f<Foo>[TraitImpl](...); + (* The return type of the call to f is: + T::S ~~> TraitImpl::S ~~> usize + *) + ]} + + Several remarks: + - as we do not allow higher-order types (yet) then local clauses (and + sub-clauses) can't have generic arguments + - the [TraitRef] case only happens because of substitution, the role of + the normalization is in particular to eliminate it. Inside a [TraitRef] + there is necessarily: + - an id referencing a local (sub-)clause, that is an id using the variants + [Self], [Clause], [ItemClause] and [ParentClause] exclusively. We can't + simplify those cases: all we can do is remove the [TraitRef] wrapper + by leveraging the fact that the generic arguments must be empty. + - a [TraitImpl]. Note that the [TraitImpl] is necessarily just a [TraitImpl], + it can't be for instance a [ParentClause(TraitImpl ...)] because the + trait resolution would then directly reference the implementation + designated by [ParentClause(TraitImpl ...)] (and same for the other cases). + In this case we can lookup the trait implementation and recursively project + over it. + *) +and ctx_normalize_trait_instance_id : + 'r. + 'r norm_ctx -> + 'r T.trait_instance_id -> + 'r T.trait_instance_id * 'r T.trait_ref option = + fun ctx id -> + match id with + | Self -> (id, None) + | TraitImpl _ -> + (* The [TraitImpl] shouldn't be inside any projection - we check this + elsewhere by asserting that whenever we return [None] for the impl + trait ref, then the id actually refers to a local clause. *) + (id, None) + | Clause _ -> (id, None) + | BuiltinOrAuto _ -> (id, None) + | ParentClause (inst_id, decl_id, clause_id) -> ( + let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in + (* Check if the inst_id refers to a specific implementation, if yes project *) + match impl with + | None -> + (* This is actually a local clause *) + assert (trait_instance_id_is_local_clause inst_id); + (ParentClause (inst_id, decl_id, clause_id), None) + | Some impl -> + (* We figure out the parent clause by doing the following: + {[ + // The implementation we are looking at + impl Impl1 : Trait1 { ... } + + // Check the trait it implements + trait Trait1 : ParentTrait1 + ParentTrait2 { ... } + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + those are the parent clauses + ]} + *) + (* Lookup the clause *) + let impl_id = + TypesUtils.trait_instance_id_as_trait_impl impl.trait_id + in + let clause = + ctx_lookup_trait_impl_parent_clause ctx impl_id impl.generics + clause_id + in + (* Normalize the clause *) + let clause = ctx_normalize_trait_ref ctx clause in + (TraitRef clause, Some clause)) + | ItemClause (inst_id, decl_id, item_name, clause_id) -> ( + let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in + (* Check if the inst_id refers to a specific implementation, if yes project *) + match impl with + | None -> + (* This is actually a local clause *) + assert (trait_instance_id_is_local_clause inst_id); + (ItemClause (inst_id, decl_id, item_name, clause_id), None) + | Some impl -> + (* We figure out the item clause by doing the following: + {[ + // The implementation we are looking at + impl Impl1 : Trait1<R> { + type S = ... + with Impl2 : Trait2 ... // Instances satisfying the declared bounds + ^^^^^^^^^^^^^^^^^^ + Lookup the clause from here + } + ]} + *) + (* Lookup the impl *) + let impl_id = + TypesUtils.trait_instance_id_as_trait_impl impl.trait_id + in + let clause = + ctx_lookup_trait_impl_item_clause ctx impl_id impl.generics + item_name clause_id + in + (* Normalize the clause *) + let clause = ctx_normalize_trait_ref ctx clause in + (TraitRef clause, Some clause)) + | TraitRef { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref } -> + (* We can't simplify the id *yet* : we will simplify it when projecting. + However, we have an implementation to return *) + (* Normalize the generics *) + let generics = ctx_normalize_generic_args ctx generics in + let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in + let trait_ref : 'r T.trait_ref = + { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref } + in + (TraitRef trait_ref, Some trait_ref) + | TraitRef trait_ref -> + (* The trait instance id necessarily refers to a local sub-clause. We + can't project over it and can only peel off the [TraitRef] wrapper *) + assert (trait_instance_id_is_local_clause trait_ref.trait_id); + assert (trait_ref.generics = TypesUtils.mk_empty_generic_args); + (trait_ref.trait_id, None) + | FnPointer ty -> + let ty = ctx_normalize_ty ctx ty in + (* TODO: we might want to return the ref to the function pointer, + in order to later normalize a call to this function pointer *) + (FnPointer ty, None) + | UnknownTrait _ -> + (* This is actually an error case *) + (id, None) + +and ctx_normalize_generic_args (ctx : 'r norm_ctx) + (generics : 'r T.generic_args) : 'r T.generic_args = + let { T.regions; types; const_generics; trait_refs } = generics in + let types = List.map (ctx_normalize_ty ctx) types in + let trait_refs = List.map (ctx_normalize_trait_ref ctx) trait_refs in + { T.regions; types; const_generics; trait_refs } + +and ctx_normalize_trait_ref (ctx : 'r norm_ctx) (trait_ref : 'r T.trait_ref) : + 'r T.trait_ref = + log#ldebug + (lazy + ("ctx_normalize_trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref:\n" + ^ T.show_trait_ref ctx.pp_r trait_ref)); + let { T.trait_id; generics; trait_decl_ref } = trait_ref in + (* Check if the id is an impl, otherwise normalize it *) + let trait_id, norm_trait_ref = ctx_normalize_trait_instance_id ctx trait_id in + match norm_trait_ref with + | None -> + log#ldebug + (lazy + ("ctx_normalize_trait_ref: no norm: " + ^ ctx.trait_instance_id_to_string trait_id)); + let generics = ctx_normalize_generic_args ctx generics in + let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in + { T.trait_id; generics; trait_decl_ref } + | Some trait_ref -> + log#ldebug + (lazy + ("ctx_normalize_trait_ref: normalized to: " + ^ ctx.trait_ref_to_string trait_ref)); + assert (generics = TypesUtils.mk_empty_generic_args); + trait_ref + +(* Not sure this one is really necessary *) +and ctx_normalize_trait_decl_ref (ctx : 'r norm_ctx) + (trait_decl_ref : 'r T.trait_decl_ref) : 'r T.trait_decl_ref = + let { T.trait_decl_id; decl_generics } = trait_decl_ref in + let decl_generics = ctx_normalize_generic_args ctx decl_generics in + { T.trait_decl_id; decl_generics } + +let ctx_normalize_trait_type_constraint (ctx : 'r norm_ctx) + (ttc : 'r T.trait_type_constraint) : 'r T.trait_type_constraint = + let { T.trait_ref; generics; type_name; ty } = ttc in + let trait_ref = ctx_normalize_trait_ref ctx trait_ref in + let generics = ctx_normalize_generic_args ctx generics in + let ty = ctx_normalize_ty ctx ty in + { T.trait_ref; generics; type_name; ty } + +let generic_params_to_string ctx x = + "<" ^ String.concat ", " (fst (PA.generic_params_to_strings ctx x)) ^ ">" + +let mk_snorm_ctx (ctx : C.eval_ctx) : T.RegionVarId.id T.region norm_ctx = + let get_ty_repr x = C.STraitTypeRefMap.find_opt x ctx.norm_trait_stypes in + { + ctx; + get_ty_repr; + convert_ety = TypesUtils.ety_no_regions_to_sty; + convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref; + ty_to_string = PA.sty_to_string ctx; + generic_params_to_string = generic_params_to_string ctx; + generic_args_to_string = PA.sgeneric_args_to_string ctx; + trait_ref_to_string = PA.strait_ref_to_string ctx; + trait_instance_id_to_string = PA.strait_instance_id_to_string ctx; + pp_r = T.pp_region T.pp_region_var_id; + } + +let mk_rnorm_ctx (ctx : C.eval_ctx) : T.RegionId.id T.region norm_ctx = + let get_ty_repr x = C.RTraitTypeRefMap.find_opt x ctx.norm_trait_rtypes in + { + ctx; + get_ty_repr; + convert_ety = TypesUtils.ety_no_regions_to_rty; + convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref; + ty_to_string = PA.rty_to_string ctx; + generic_params_to_string = generic_params_to_string ctx; + generic_args_to_string = PA.rgeneric_args_to_string ctx; + trait_ref_to_string = PA.rtrait_ref_to_string ctx; + trait_instance_id_to_string = PA.rtrait_instance_id_to_string ctx; + pp_r = T.pp_region T.pp_region_id; + } + +let mk_enorm_ctx (ctx : C.eval_ctx) : T.erased_region norm_ctx = + let get_ty_repr x = C.ETraitTypeRefMap.find_opt x ctx.norm_trait_etypes in + { + ctx; + get_ty_repr; + convert_ety = (fun x -> x); + convert_etrait_ref = (fun x -> x); + ty_to_string = PA.ety_to_string ctx; + generic_params_to_string = generic_params_to_string ctx; + generic_args_to_string = PA.egeneric_args_to_string ctx; + trait_ref_to_string = PA.etrait_ref_to_string ctx; + trait_instance_id_to_string = PA.etrait_instance_id_to_string ctx; + pp_r = T.pp_erased_region; + } + +let ctx_normalize_sty (ctx : C.eval_ctx) (ty : T.sty) : T.sty = + ctx_normalize_ty (mk_snorm_ctx ctx) ty + +let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty = + ctx_normalize_ty (mk_rnorm_ctx ctx) ty + +let ctx_normalize_ety (ctx : C.eval_ctx) (ty : T.ety) : T.ety = + ctx_normalize_ty (mk_enorm_ctx ctx) ty + +let ctx_normalize_rtrait_type_constraint (ctx : C.eval_ctx) + (ttc : T.rtrait_type_constraint) : T.rtrait_type_constraint = + ctx_normalize_trait_type_constraint (mk_rnorm_ctx ctx) ttc + +(** Same as [type_decl_get_instantiated_variants_fields_rtypes] but normalizes the types *) +let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx) + (def : T.type_decl) (generics : T.rgeneric_args) : + (T.VariantId.id option * T.rty list) list = + let res = + Subst.type_decl_get_instantiated_variants_fields_rtypes def generics + in + List.map + (fun (variant_id, types) -> + (variant_id, List.map (ctx_normalize_rty ctx) types)) + res + +(** Same as [type_decl_get_instantiated_field_rtypes] but normalizes the types *) +let type_decl_get_inst_norm_field_rtypes (ctx : C.eval_ctx) (def : T.type_decl) + (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) : + T.rty list = + let types = + Subst.type_decl_get_instantiated_field_rtypes def opt_variant_id generics + in + List.map (ctx_normalize_rty ctx) types + +(** Same as [ctx_adt_value_get_instantiated_field_rtypes] but normalizes the types *) +let ctx_adt_value_get_inst_norm_field_rtypes (ctx : C.eval_ctx) + (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) : + T.rty list = + let types = + Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id generics + in + List.map (ctx_normalize_rty ctx) types + +(** Same as [ctx_adt_value_get_instantiated_field_etypes] but normalizes the types *) +let type_decl_get_inst_norm_field_etypes (ctx : C.eval_ctx) (def : T.type_decl) + (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) : + T.ety list = + let types = + Subst.type_decl_get_instantiated_field_etypes def opt_variant_id generics + in + List.map (ctx_normalize_ety ctx) types + +(** Same as [ctx_adt_get_instantiated_field_etypes] but normalizes the types *) +let ctx_adt_get_inst_norm_field_etypes (ctx : C.eval_ctx) + (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) + (generics : T.egeneric_args) : T.ety list = + let types = + Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id + generics + in + List.map (ctx_normalize_ety ctx) types + +(** Same as [substitute_signature] but normalizes the types *) +let ctx_subst_norm_signature (ctx : C.eval_ctx) + (asubst : T.RegionGroupId.id -> V.AbstractionId.id) + (r_subst : T.RegionVarId.id -> T.RegionId.id) + (ty_subst : T.TypeVarId.id -> T.rty) + (cg_subst : T.ConstGenericVarId.id -> T.const_generic) + (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id) + (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig = + let sg = + Subst.substitute_signature asubst r_subst ty_subst cg_subst tr_subst tr_self + sg + in + let { A.regions_hierarchy; inputs; output; trait_type_constraints } = sg in + let inputs = List.map (ctx_normalize_rty ctx) inputs in + let output = ctx_normalize_rty ctx output in + let trait_type_constraints = + List.map (ctx_normalize_rtrait_type_constraint ctx) trait_type_constraints + in + { regions_hierarchy; inputs; output; trait_type_constraints } diff --git a/compiler/Assumed.ml b/compiler/Assumed.ml index 11cd5666..79f6b0d4 100644 --- a/compiler/Assumed.ml +++ b/compiler/Assumed.ml @@ -63,200 +63,52 @@ module Sig = struct let empty_const_generic_params : T.const_generic_var list = [] + let mk_generic_args regions types const_generics : T.sgeneric_args = + { regions; types; const_generics; trait_refs = [] } + + let mk_generic_params regions types const_generics : T.generic_params = + { regions; types; const_generics; trait_clauses = [] } + let mk_ref_ty (r : T.RegionVarId.id T.region) (ty : T.sty) (is_mut : bool) : T.sty = let ref_kind = if is_mut then T.Mut else T.Shared in mk_ref_ty r ty ref_kind let mk_array_ty (ty : T.sty) (cg : T.const_generic) : T.sty = - Adt (Assumed Array, [], [ ty ], [ cg ]) + Adt (Assumed Array, mk_generic_args [] [ ty ] [ cg ]) - let mk_slice_ty (ty : T.sty) : T.sty = Adt (Assumed Slice, [], [ ty ], []) - let range_ty : T.sty = Adt (Assumed Range, [], [ usize_ty ], []) + let mk_slice_ty (ty : T.sty) : T.sty = + Adt (Assumed Slice, mk_generic_args [] [ ty ] []) - (** [fn<T>(&'a mut T, T) -> T] *) - let mem_replace_sig : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] (* <'a> *) in - let regions_hierarchy = [ region_group_0 ] (* [{<'a>}] *) in - let type_params = [ type_param_0 ] (* <T> *) in - let inputs = - [ mk_ref_ty rvar_0 tvar_0 true (* &'a mut T *); tvar_0 (* T *) ] + let mk_sig generics regions_hierarchy inputs output : A.fun_sig = + let preds : T.predicates = + { regions_outlive = []; types_outlive = []; trait_type_constraints = [] } in - let output = tvar_0 (* T *) in { - region_params; - num_early_bound_regions = 0; + is_unsafe = false; + generics; + preds; + parent_params_info = None; regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; inputs; output; } (** [fn<T>(T) -> Box<T>] *) let box_new_sig : A.fun_sig = - { - region_params = []; - num_early_bound_regions = 0; - regions_hierarchy = []; - type_params = [ type_param_0 ] (* <T> *); - const_generic_params = empty_const_generic_params; - inputs = [ tvar_0 (* T *) ]; - output = mk_box_ty tvar_0 (* Box<T> *); - } + let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in + let regions_hierarchy = [] in + let inputs = [ tvar_0 (* T *) ] in + let output = mk_box_ty tvar_0 (* Box<T> *) in + mk_sig generics regions_hierarchy inputs output (** [fn<T>(Box<T>) -> ()] *) let box_free_sig : A.fun_sig = - { - region_params = []; - num_early_bound_regions = 0; - regions_hierarchy = []; - type_params = [ type_param_0 ] (* <T> *); - const_generic_params = empty_const_generic_params; - inputs = [ mk_box_ty tvar_0 (* Box<T> *) ]; - output = mk_unit_ty (* () *); - } - - (** Helper for [Box::deref_shared] and [Box::deref_mut]. - Returns: - [fn<'a, T>(&'a (mut) Box<T>) -> &'a (mut) T] - *) - let box_deref_gen_sig (is_mut : bool) : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in - let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params = [ type_param_0 ] (* <T> *); - const_generic_params = empty_const_generic_params; - inputs = - [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box<T> *) ]; - output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *); - } - - (** [fn<'a, T>(&'a Box<T>) -> &'a T] *) - let box_deref_shared_sig = box_deref_gen_sig false - - (** [fn<'a, T>(&'a mut Box<T>) -> &'a mut T] *) - let box_deref_mut_sig = box_deref_gen_sig true - - (** [fn<T>() -> Vec<T>] *) - let vec_new_sig : A.fun_sig = - let region_params = [] in + let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in let regions_hierarchy = [] in - let type_params = [ type_param_0 ] (* <T> *) in - let inputs = [] in - let output = mk_vec_ty tvar_0 (* Vec<T> *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } - - (** [fn<T>(&'a mut Vec<T>, T)] *) - let vec_push_sig : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in - let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in - let inputs = - [ - mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *); - tvar_0 (* T *); - ] - in + let inputs = [ mk_box_ty tvar_0 (* Box<T> *) ] in let output = mk_unit_ty (* () *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } - - (** [fn<T>(&'a mut Vec<T>, usize, T)] *) - let vec_insert_sig : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in - let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in - let inputs = - [ - mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *); - mk_usize_ty (* usize *); - tvar_0 (* T *); - ] - in - let output = mk_unit_ty (* () *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } - - (** [fn<T>(&'a Vec<T>) -> usize] *) - let vec_len_sig : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in - let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in - let inputs = - [ mk_ref_ty rvar_0 (mk_vec_ty tvar_0) false (* &'a Vec<T> *) ] - in - let output = mk_usize_ty (* usize *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } - - (** Helper: - [fn<T>(&'a (mut) Vec<T>, usize) -> &'a (mut) T] - *) - let vec_index_gen_sig (is_mut : bool) : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in - let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in - let inputs = - [ - mk_ref_ty rvar_0 (mk_vec_ty tvar_0) is_mut (* &'a (mut) Vec<T> *); - mk_usize_ty (* usize *); - ] - in - let output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } - - (** [fn<T>(&'a Vec<T>, usize) -> &'a T] *) - let vec_index_shared_sig : A.fun_sig = vec_index_gen_sig false - - (** [fn<T>(&'a mut Vec<T>, usize) -> &'a mut T] *) - let vec_index_mut_sig : A.fun_sig = vec_index_gen_sig true + mk_sig generics regions_hierarchy inputs output (** Array/slice functions *) @@ -275,10 +127,10 @@ module Sig = struct let mk_array_slice_borrow_sig (cgs : T.const_generic_var list) (input_ty : T.TypeVarId.id -> T.sty) (index_ty : T.sty option) (output_ty : T.TypeVarId.id -> T.sty) (is_mut : bool) : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in + let generics = + mk_generic_params [ region_param_0 ] [ type_param_0 ] cgs (* <'a, T> *) + in let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in let inputs = [ mk_ref_ty rvar_0 @@ -294,15 +146,7 @@ module Sig = struct (output_ty type_param_0.index) is_mut (* &'a (mut) output_ty<T> *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = cgs; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output let mk_array_slice_index_sig (is_array : bool) (is_mut : bool) : A.fun_sig = (* Array<T, N> *) @@ -328,50 +172,53 @@ module Sig = struct let cgs = [ cg_param_0 ] in mk_array_slice_borrow_sig cgs input_ty None output_ty is_mut - let mk_array_slice_subslice_sig (is_array : bool) (is_mut : bool) : A.fun_sig - = - (* Array<T, N> *) - let input_ty id = - if is_array then mk_array_ty (T.TypeVar id) cgvar_0 - else mk_slice_ty (T.TypeVar id) + let array_repeat_sig = + let generics = + (* <T, N> *) + mk_generic_params [] [ type_param_0 ] [ cg_param_0 ] in - (* Range *) - let index_ty = range_ty in - (* Slice<T> *) - let output_ty id = mk_slice_ty (T.TypeVar id) in - let cgs = if is_array then [ cg_param_0 ] else [] in - mk_array_slice_borrow_sig cgs input_ty (Some index_ty) output_ty is_mut - - let array_subslice_sig (is_mut : bool) = - mk_array_slice_subslice_sig true is_mut - - let slice_subslice_sig (is_mut : bool) = - mk_array_slice_subslice_sig false is_mut + let regions_hierarchy = [] (* <> *) in + let inputs = [ tvar_0 (* T *) ] in + let output = + (* [T; N] *) + mk_array_ty tvar_0 cgvar_0 + in + mk_sig generics regions_hierarchy inputs output (** Helper: [fn<T>(&'a [T]) -> usize] *) let slice_len_sig : A.fun_sig = - (* The signature fields *) - let region_params = [ region_param_0 ] in + let generics = + mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *) + in let regions_hierarchy = [ region_group_0 ] (* <'a> *) in - let type_params = [ type_param_0 ] (* <T> *) in let inputs = [ mk_ref_ty rvar_0 (mk_slice_ty tvar_0) false (* &'a [T] *) ] in let output = mk_usize_ty (* usize *) in - { - region_params; - num_early_bound_regions = 0; - regions_hierarchy; - type_params; - const_generic_params = empty_const_generic_params; - inputs; - output; - } + mk_sig generics regions_hierarchy inputs output end -type assumed_info = A.assumed_fun_id * A.fun_sig * bool * name +type raw_assumed_fun_info = + A.assumed_fun_id * A.fun_sig * bool * name * bool list option + +type assumed_fun_info = { + fun_id : A.assumed_fun_id; + fun_sig : A.fun_sig; + can_fail : bool; + name : name; + keep_types : bool list option; + (** We may want to filter some type arguments. + + For instance, all the `Vec` functions (and the `Vec` type itself) take + an `Allocator` type as argument, that we ignore. + *) +} + +let mk_assumed_fun_info (raw : raw_assumed_fun_info) : assumed_fun_info = + let fun_id, fun_sig, can_fail, name, keep_types = raw in + { fun_id; fun_sig; can_fail; name; keep_types } (** The list of assumed functions and all their information: - their signature @@ -384,94 +231,72 @@ type assumed_info = A.assumed_fun_id * A.fun_sig * bool * name a [usize], we have to make sure that vectors are bounded by the max usize. As a consequence, [Vec::push] is monadic. *) -let assumed_infos : assumed_info list = - let deref_pre = [ "core"; "ops"; "deref" ] in - let vec_pre = [ "alloc"; "vec"; "Vec" ] in - let index_pre = [ "core"; "ops"; "index" ] in +let raw_assumed_fun_infos : raw_assumed_fun_info list = [ - (A.Replace, Sig.mem_replace_sig, false, to_name [ "core"; "mem"; "replace" ]); - (BoxNew, Sig.box_new_sig, false, to_name [ "alloc"; "boxed"; "Box"; "new" ]); + ( BoxNew, + Sig.box_new_sig, + false, + to_name [ "alloc"; "boxed"; "Box"; "new" ], + Some [ true; false ] ); + (* BoxFree shouldn't be used *) ( BoxFree, Sig.box_free_sig, false, - to_name [ "alloc"; "boxed"; "Box"; "free" ] ); - ( BoxDeref, - Sig.box_deref_shared_sig, - false, - to_name (deref_pre @ [ "Deref"; "deref" ]) ); - ( BoxDerefMut, - Sig.box_deref_mut_sig, - false, - to_name (deref_pre @ [ "DerefMut"; "deref_mut" ]) ); - (VecNew, Sig.vec_new_sig, false, to_name (vec_pre @ [ "new" ])); - (VecPush, Sig.vec_push_sig, true, to_name (vec_pre @ [ "push" ])); - (VecInsert, Sig.vec_insert_sig, true, to_name (vec_pre @ [ "insert" ])); - (VecLen, Sig.vec_len_sig, false, to_name (vec_pre @ [ "len" ])); - ( VecIndex, - Sig.vec_index_shared_sig, - true, - to_name (index_pre @ [ "Index"; "index" ]) ); - ( VecIndexMut, - Sig.vec_index_mut_sig, - true, - to_name (index_pre @ [ "IndexMut"; "index_mut" ]) ); + to_name [ "alloc"; "boxed"; "Box"; "free" ], + Some [ true; false ] ); (* Array Index *) ( ArrayIndexShared, Sig.array_index_sig false, true, - to_name [ "@ArrayIndexShared" ] ); - (ArrayIndexMut, Sig.array_index_sig true, true, to_name [ "@ArrayIndexMut" ]); + to_name [ "@ArrayIndexShared" ], + None ); + ( ArrayIndexMut, + Sig.array_index_sig true, + true, + to_name [ "@ArrayIndexMut" ], + None ); (* Array to slice*) ( ArrayToSliceShared, Sig.array_to_slice_sig false, true, - to_name [ "@ArrayToSliceShared" ] ); + to_name [ "@ArrayToSliceShared" ], + None ); ( ArrayToSliceMut, Sig.array_to_slice_sig true, true, - to_name [ "@ArrayToSliceMut" ] ); - (* Array Subslice *) - ( ArraySubsliceShared, - Sig.array_subslice_sig false, - true, - to_name [ "@ArraySubsliceShared" ] ); - ( ArraySubsliceMut, - Sig.array_subslice_sig true, - true, - to_name [ "@ArraySubsliceMut" ] ); + to_name [ "@ArrayToSliceMut" ], + None ); + (* Array Repeat *) + (ArrayRepeat, Sig.array_repeat_sig, false, to_name [ "@ArrayRepeat" ], None); (* Slice Index *) ( SliceIndexShared, Sig.slice_index_sig false, true, - to_name [ "@SliceIndexShared" ] ); - (SliceIndexMut, Sig.slice_index_sig true, true, to_name [ "@SliceIndexMut" ]); - (* Slice Subslice *) - ( SliceSubsliceShared, - Sig.slice_subslice_sig false, - true, - to_name [ "@SliceSubsliceShared" ] ); - ( SliceSubsliceMut, - Sig.slice_subslice_sig true, + to_name [ "@SliceIndexShared" ], + None ); + ( SliceIndexMut, + Sig.slice_index_sig true, true, - to_name [ "@SliceSubsliceMut" ] ); - (SliceLen, Sig.slice_len_sig, false, to_name [ "@SliceLen" ]); + to_name [ "@SliceIndexMut" ], + None ); + (SliceLen, Sig.slice_len_sig, false, to_name [ "@SliceLen" ], None); ] -let get_assumed_info (id : A.assumed_fun_id) : assumed_info = - match List.find_opt (fun (id', _, _, _) -> id = id') assumed_infos with +let assumed_fun_infos : assumed_fun_info list = + List.map mk_assumed_fun_info raw_assumed_fun_infos + +let get_assumed_fun_info (id : A.assumed_fun_id) : assumed_fun_info = + match List.find_opt (fun x -> id = x.fun_id) assumed_fun_infos with | Some info -> info | None -> raise - (Failure ("get_assumed_info: not found: " ^ A.show_assumed_fun_id id)) + (Failure ("get_assumed_fun_info: not found: " ^ A.show_assumed_fun_id id)) -let get_assumed_sig (id : A.assumed_fun_id) : A.fun_sig = - let _, sg, _, _ = get_assumed_info id in - sg +let get_assumed_fun_sig (id : A.assumed_fun_id) : A.fun_sig = + (get_assumed_fun_info id).fun_sig -let get_assumed_name (id : A.assumed_fun_id) : fun_name = - let _, _, _, name = get_assumed_info id in - name +let get_assumed_fun_name (id : A.assumed_fun_id) : fun_name = + (get_assumed_fun_info id).name -let assumed_can_fail (id : A.assumed_fun_id) : bool = - let _, _, b, _ = get_assumed_info id in - b +let assumed_fun_can_fail (id : A.assumed_fun_id) : bool = + (get_assumed_fun_info id).can_fail diff --git a/compiler/Config.ml b/compiler/Config.ml index bd80769f..a487f9e2 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -124,7 +124,7 @@ let always_deconstruct_adts_with_matches = ref false (** Controls whether we need to use a state to model the external world (I/O, for instance). *) -let use_state = ref true +let use_state = ref false (** Controls whether we use fuel to control termination. *) @@ -160,7 +160,7 @@ let backward_no_state_update = ref false files for the types, clauses and functions, or if we group them in one file. *) -let split_files = ref true +let split_files = ref false (** Generate the library entry point, if the crate is split between different files. @@ -306,13 +306,6 @@ let filter_useless_monadic_calls = ref true *) let filter_useless_functions = ref true -(** Obsolete. TODO: remove. - - For Lean we used to parameterize the entire development by a section variable - called opaque_defs, of type OpaqueDefs. - *) -let wrap_opaque_in_sig = ref false - (** Use short names for the record fields. Some backends can't disambiguate records when their field names have collisions. @@ -323,3 +316,23 @@ let wrap_opaque_in_sig = ref false information), we use short names (i.e., the original field names). *) let record_fields_short_names = ref false + +(** Parameterize the traits with their associated types, so as not to use + types as first class objects. + + This is useful for some backends with limited expressiveness like HOL4, + and to account for type constraints (like [fn f<T : Foo>(...) where T::bar = usize]). + *) +let parameterize_trait_types = ref false + +(** For sanity check: type check the generated pure code (activates checks in + several places). + + TODO: deactivated for now because we need to implement the normalization of + trait associated types in the pure code. + *) +let type_check_pure_code = ref false + +(** Shall we fail hard if we encounter an issue, or should we attempt to go + as far as possible while leaving "holes" in the generated code? *) +let fail_hard = ref true diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index 2ca5653d..dac64a9a 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -5,6 +5,7 @@ open LlbcAst module V = Values open ValuesUtils open Identifiers +module L = Logging (** The [Id] module for dummy variables. @@ -17,6 +18,9 @@ IdGen () type dummy_var_id = DummyVarId.id [@@deriving show, ord] +(** The local logger *) +let log = L.contexts_log + (** Some global counters. Note that those counters were initially stored in {!eval_ctx} values, @@ -40,6 +44,7 @@ type dummy_var_id = DummyVarId.id [@@deriving show, ord] fn f x : fun_type = let id = fresh_id () in ... + fun () -> ... let g = f x in // <-- the fresh identifier gets generated here let x1 = g () in // <-- no fresh generation here @@ -250,27 +255,127 @@ type type_context = { } [@@deriving show] -type fun_context = { fun_decls : fun_decl FunDeclId.Map.t } [@@deriving show] +type fun_context = { + fun_decls : fun_decl FunDeclId.Map.t; + fun_infos : FunsAnalysis.fun_info FunDeclId.Map.t; +} +[@@deriving show] type global_context = { global_decls : global_decl GlobalDeclId.Map.t } [@@deriving show] +type trait_decls_context = { trait_decls : trait_decl TraitDeclId.Map.t } +[@@deriving show] + +type trait_impls_context = { trait_impls : trait_impl TraitImplId.Map.t } +[@@deriving show] + +type decls_ctx = { + type_ctx : type_context; + fun_ctx : fun_context; + global_ctx : global_context; + trait_decls_ctx : trait_decls_context; + trait_impls_ctx : trait_impls_context; +} +[@@deriving show] + +(** A reference to a trait associated type *) +type 'r trait_type_ref = { trait_ref : 'r trait_ref; type_name : string } +[@@deriving show, ord] + +type etrait_type_ref = erased_region trait_type_ref [@@deriving show, ord] + +type rtrait_type_ref = Types.RegionId.id Types.region trait_type_ref +[@@deriving show, ord] + +type strait_type_ref = Types.RegionVarId.id Types.region trait_type_ref +[@@deriving show, ord] + +(* TODO: correctly use the functors so as not to have a duplication below *) +module ETraitTypeRefOrd = struct + type t = etrait_type_ref + + let compare = compare_etrait_type_ref + let to_string = show_etrait_type_ref + let pp_t = pp_etrait_type_ref + let show_t = show_etrait_type_ref +end + +module RTraitTypeRefOrd = struct + type t = rtrait_type_ref + + let compare = compare_rtrait_type_ref + let to_string = show_rtrait_type_ref + let pp_t = pp_rtrait_type_ref + let show_t = show_rtrait_type_ref +end + +module STraitTypeRefOrd = struct + type t = strait_type_ref + + let compare = compare_strait_type_ref + let to_string = show_strait_type_ref + let pp_t = pp_strait_type_ref + let show_t = show_strait_type_ref +end + +module ETraitTypeRefMap = Collections.MakeMap (ETraitTypeRefOrd) +module RTraitTypeRefMap = Collections.MakeMap (RTraitTypeRefOrd) +module STraitTypeRefMap = Collections.MakeMap (STraitTypeRefOrd) + (** Evaluation context *) type eval_ctx = { type_context : type_context; fun_context : fun_context; global_context : global_context; + trait_decls_context : trait_decls_context; + trait_impls_context : trait_impls_context; region_groups : RegionGroupId.id list; type_vars : type_var list; const_generic_vars : const_generic_var list; + const_generic_vars_map : typed_value Types.ConstGenericVarId.Map.t; + (** The map from const generic vars to their values. Those values + can be symbolic values or concrete values (in the latter case: + if we run in interpreter mode) *) + norm_trait_etypes : ety ETraitTypeRefMap.t; + (** The normalized trait types (a map from trait types to their representatives). + Note that this doesn't support account higher-order types. *) + norm_trait_rtypes : rty RTraitTypeRefMap.t; + (** We need this because we manipulate two kinds of types. + Note that we actually forbid regions from appearing both in the trait + references and in the constraints given to the associated types, + meaning that we don't have to worry about mismatches due to changes + in region ids. + + TODO: how not to duplicate? + *) + norm_trait_stypes : sty STraitTypeRefMap.t; + (** We sometimes need to normalize types in non-instantiated signatures. + + Note that we either need to use the etypes/rtypes maps, or the stypes map. + This means that we either compute the maps for etypes and rtypes, or compute + the one for stypes (we don't always compute and carry all the maps). + *) env : env; ended_regions : RegionId.Set.t; } [@@deriving show] +let lookup_type_var_opt (ctx : eval_ctx) (vid : TypeVarId.id) : type_var option + = + if TypeVarId.to_int vid < List.length ctx.type_vars then + Some (TypeVarId.nth ctx.type_vars vid) + else None + let lookup_type_var (ctx : eval_ctx) (vid : TypeVarId.id) : type_var = TypeVarId.nth ctx.type_vars vid +let lookup_const_generic_var_opt (ctx : eval_ctx) (vid : ConstGenericVarId.id) : + const_generic_var option = + if ConstGenericVarId.to_int vid < List.length ctx.const_generic_vars then + Some (ConstGenericVarId.nth ctx.const_generic_vars vid) + else None + let lookup_const_generic_var (ctx : eval_ctx) (vid : ConstGenericVarId.id) : const_generic_var = ConstGenericVarId.nth ctx.const_generic_vars vid @@ -304,6 +409,12 @@ let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) : global_decl = GlobalDeclId.Map.find gid ctx.global_context.global_decls +let ctx_lookup_trait_decl (ctx : eval_ctx) (id : TraitDeclId.id) : trait_decl = + TraitDeclId.Map.find id ctx.trait_decls_context.trait_decls + +let ctx_lookup_trait_impl (ctx : eval_ctx) (id : TraitImplId.id) : trait_impl = + TraitImplId.Map.find id ctx.trait_impls_context.trait_impls + (** Retrieve a variable's value in the current frame *) let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value = snd (env_lookup_var env vid) @@ -312,6 +423,11 @@ let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value = let ctx_lookup_var_value (ctx : eval_ctx) (vid : VarId.id) : typed_value = env_lookup_var_value ctx.env vid +(** Retrieve a const generic value in an evaluation context *) +let ctx_lookup_const_generic_value (ctx : eval_ctx) (vid : ConstGenericVarId.id) + : typed_value = + Types.ConstGenericVarId.Map.find vid ctx.const_generic_vars_map + (** Update a variable's value in the current frame. This is a helper function: it can break invariants and doesn't perform @@ -361,6 +477,15 @@ let ctx_push_var (ctx : eval_ctx) (var : var) (v : typed_value) : eval_ctx = *) let ctx_push_vars (ctx : eval_ctx) (vars : (var * typed_value) list) : eval_ctx = + log#ldebug + (lazy + ("push_vars:\n" + ^ String.concat "\n" + (List.map + (fun (var, value) -> + (* We can unfortunately not use Print because it depends on Contexts... *) + show_var var ^ " -> " ^ V.show_typed_value value) + vars))); assert ( List.for_all (fun (var, (value : typed_value)) -> var.var_ty = value.ty) diff --git a/compiler/Driver.ml b/compiler/Driver.ml index b646a53d..128ae890 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -17,11 +17,15 @@ let log = main_log let _ = (* Set up the logging - for now we use default values - TODO: use the * command-line arguments *) - (* By setting a level for the main_logger_handler, we filter everything *) + (* By setting a level for the main_logger_handler, we filter everything. + To have a good trace: one should switch between Info and Debug. + *) Easy_logging.Handlers.set_level main_logger_handler EL.Debug; main_log#set_level EL.Info; llbc_of_json_logger#set_level EL.Info; pre_passes_log#set_level EL.Info; + associated_types_log#set_level EL.Info; + contexts_log#set_level EL.Info; interpreter_log#set_level EL.Info; statements_log#set_level EL.Info; loops_match_ctxs_log#set_level EL.Info; @@ -37,7 +41,7 @@ let _ = pure_utils_log#set_level EL.Info; symbolic_to_pure_log#set_level EL.Info; pure_micro_passes_log#set_level EL.Info; - pure_to_extract_log#set_level EL.Info; + extract_log#set_level EL.Info; translate_log#set_level EL.Info; scc_log#set_level EL.Info; reorder_decls_log#set_level EL.Info @@ -62,6 +66,9 @@ let () = (* Read the command line arguments *) let dest_dir = ref "" in + (* Print the imported llbc *) + let print_llbc = ref false in + let spec = [ ( "-backend", @@ -86,9 +93,9 @@ let () = Arg.Set extract_decreases_clauses, " Use decreases clauses/termination measures for the recursive \ definitions" ); - ( "-no-state", - Arg.Clear use_state, - " Do not use state-error monads, simply use error monads" ); + ( "-state", + Arg.Set use_state, + " Use a *state*-error monads, instead of an error monads" ); ( "-use-fuel", Arg.Set use_fuel, " Use a fuel parameter to control divergence" ); @@ -99,10 +106,10 @@ let () = Arg.Set extract_template_decreases_clauses, " Generate templates for the required decreases clauses/termination \ measures, in a dedicated file. Implies -decreases-clauses" ); - ( "-no-split-files", - Arg.Clear split_files, - " Do not split the definitions between different files for types, \ - functions, etc." ); + ( "-split-files", + Arg.Set split_files, + " Split the definitions between different files for types, functions, \ + etc." ); ( "-no-check-inv", Arg.Clear check_invariants, " Deactivate the invariant sanity checks performed at every evaluation \ @@ -114,6 +121,8 @@ let () = ( "-lean-default-lakefile", Arg.Clear lean_gen_lakefile, " Generate a default lakefile.lean (Lean only)" ); + ("-print-llbc", Arg.Set print_llbc, " Print the imported LLBC"); + ("-k", Arg.Clear fail_hard, " Do not fail hard in case of error"); ] in @@ -127,6 +136,7 @@ let () = in if !extract_template_decreases_clauses then extract_decreases_clauses := true; + if !print_llbc then main_log#set_level EL.Debug; (* Sanity check (now that the arguments are parsed!): -template-clauses ==> decrease-clauses *) assert (!extract_decreases_clauses || not !extract_template_decreases_clauses); @@ -158,14 +168,14 @@ let () = | FStar -> (* Some patterns are not supported *) decompose_monadic_let_bindings := false; - decompose_nested_let_patterns := false + decompose_nested_let_patterns := false; + (* F* can disambiguate the field names *) + record_fields_short_names := true | Coq -> (* Some patterns are not supported *) decompose_monadic_let_bindings := true; decompose_nested_let_patterns := true | Lean -> - (* The Lean backend is experimental: print a warning *) - log#lwarning (lazy "The Lean backend is experimental"); (* We don't support fuel for the Lean backend *) if !use_fuel then ( log#error "The Lean backend doesn't support the -use-fuel option"; @@ -212,28 +222,6 @@ let () = log#linfo (lazy ("Imported: " ^ filename)); log#ldebug (lazy ("\n" ^ Print.Crate.crate_to_string m ^ "\n")); - (* Print a warning if the crate contains loops (loops are experimental for now) *) - let has_loops = - A.FunDeclId.Map.exists - (fun _ -> Aeneas.LlbcAstUtils.fun_decl_has_loops) - m.functions - in - if has_loops then log#lwarning (lazy "Support for loops is experimental"); - - (* If we target Lean, we request the crates to be split into several files - whenever there are opaque functions *) - if - !backend = Lean - && A.FunDeclId.Map.exists - (fun _ (d : A.fun_decl) -> d.body = None) - m.functions - && not !split_files - then ( - log#error - "For Lean, we request the -split-file option whenever using opaque \ - functions"; - fail ()); - (* We don't support mutually recursive definitions with decreases clauses in Lean *) if !backend = Lean && !extract_decreases_clauses diff --git a/compiler/Extract.ml b/compiler/Extract.ml index c4238d83..d04f5c1d 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -3,2102 +3,104 @@ the formatter everywhere... *) -open Utils open Pure open PureUtils open TranslateCore open ExtractBase -open StringUtils open Config -module F = Format - -(** Small helper to compute the name of an int type *) -let int_name (int_ty : integer_type) = - let isize, usize, i_format, u_format = - match !backend with - | FStar | Coq | HOL4 -> - ("isize", "usize", format_of_string "i%d", format_of_string "u%d") - | Lean -> ("Isize", "Usize", format_of_string "I%d", format_of_string "U%d") - in - match int_ty with - | Isize -> isize - | I8 -> Printf.sprintf i_format 8 - | I16 -> Printf.sprintf i_format 16 - | I32 -> Printf.sprintf i_format 32 - | I64 -> Printf.sprintf i_format 64 - | I128 -> Printf.sprintf i_format 128 - | Usize -> usize - | U8 -> Printf.sprintf u_format 8 - | U16 -> Printf.sprintf u_format 16 - | U32 -> Printf.sprintf u_format 32 - | U64 -> Printf.sprintf u_format 64 - | U128 -> Printf.sprintf u_format 128 - -(** Small helper to compute the name of a unary operation *) -let unop_name (unop : unop) : string = - match unop with - | Not -> ( - match !backend with FStar | Lean -> "not" | Coq -> "negb" | HOL4 -> "~") - | Neg (int_ty : integer_type) -> ( - match !backend with Lean -> "-" | _ -> int_name int_ty ^ "_neg") - | Cast _ -> - (* We never directly use the unop name in this case *) - raise (Failure "Unsupported") - -(** Small helper to compute the name of a binary operation (note that many - binary operations like "less than" are extracted to primitive operations, - like [<]). - *) -let named_binop_name (binop : E.binop) (int_ty : integer_type) : string = - let binop = - match binop with - | Div -> "div" - | Rem -> "rem" - | Add -> "add" - | Sub -> "sub" - | Mul -> "mul" - | Lt -> "lt" - | Le -> "le" - | Ge -> "ge" - | Gt -> "gt" - | _ -> raise (Failure "Unreachable") - in - (* Remark: the Lean case is actually not used *) - match !backend with - | Lean -> int_name int_ty ^ "." ^ binop - | FStar | Coq | HOL4 -> int_name int_ty ^ "_" ^ binop - -(** A list of keywords/identifiers used by the backend and with which we - want to check collision. - - Remark: this is useful mostly to look for collisions when generating - names for *variables*. - *) -let keywords () = - let named_unops = - unop_name Not - :: List.map (fun it -> unop_name (Neg it)) T.all_signed_int_types - in - let named_binops = [ E.Div; Rem; Add; Sub; Mul ] in - let named_binops = - List.concat_map - (fun bn -> List.map (fun it -> named_binop_name bn it) T.all_int_types) - named_binops - in - let misc = - match !backend with - | FStar -> - [ - "assert"; - "assert_norm"; - "assume"; - "else"; - "fun"; - "fn"; - "FStar"; - "FStar.Mul"; - "if"; - "in"; - "include"; - "int"; - "let"; - "list"; - "match"; - "not"; - "open"; - "rec"; - "scalar_cast"; - "then"; - "type"; - "Type0"; - "Type"; - "unit"; - "val"; - "with"; - ] - | Coq -> - [ - "assert"; - "Arguments"; - "Axiom"; - "char_of_byte"; - "Check"; - "Declare"; - "Definition"; - "else"; - "End"; - "fun"; - "Fixpoint"; - "if"; - "in"; - "int"; - "Inductive"; - "Import"; - "let"; - "Lemma"; - "match"; - "Module"; - "not"; - "Notation"; - "Proof"; - "Qed"; - "rec"; - "Record"; - "Require"; - "Scope"; - "Search"; - "SearchPattern"; - "Set"; - "then"; - (* [tt] is unit *) - "tt"; - "type"; - "Type"; - "unit"; - "with"; - ] - | Lean -> - [ - "by"; - "class"; - "decreasing_by"; - "def"; - "deriving"; - "do"; - "else"; - "end"; - "for"; - "have"; - "if"; - "inductive"; - "instance"; - "import"; - "let"; - "macro"; - "match"; - "namespace"; - "opaque"; - "open"; - "run_cmd"; - "set_option"; - "simp"; - "structure"; - "syntax"; - "termination_by"; - "then"; - "Type"; - "unsafe"; - "where"; - "with"; - "opaque_defs"; - ] - | HOL4 -> - [ - "Axiom"; - "case"; - "Definition"; - "else"; - "End"; - "fix"; - "fix_exec"; - "fn"; - "fun"; - "if"; - "in"; - "int"; - "Inductive"; - "let"; - "of"; - "Proof"; - "QED"; - "then"; - "Theorem"; - ] - in - List.concat [ named_unops; named_binops; misc ] - -let assumed_adts () : (assumed_ty * string) list = - match !backend with - | Lean -> - [ - (State, "State"); - (Result, "Result"); - (Error, "Error"); - (Fuel, "Nat"); - (Option, "Option"); - (Vec, "Vec"); - (Array, "Array"); - (Slice, "Slice"); - (Str, "Str"); - (Range, "Range"); - ] - | Coq | FStar -> - [ - (State, "state"); - (Result, "result"); - (Error, "error"); - (Fuel, "nat"); - (Option, "option"); - (Vec, "vec"); - (Array, "array"); - (Slice, "slice"); - (Str, "str"); - (Range, "range"); - ] - | HOL4 -> - [ - (State, "state"); - (Result, "result"); - (Error, "error"); - (Fuel, "num"); - (Option, "option"); - (Vec, "vec"); - (Array, "array"); - (Slice, "slice"); - (Str, "str"); - (Range, "range"); - ] - -let assumed_struct_constructors () : (assumed_ty * string) list = - match !backend with - | Lean -> [ (Range, "Range.mk"); (Array, "Array.make") ] - | Coq -> [ (Range, "mk_range"); (Array, "mk_array") ] - | FStar -> [ (Range, "Mkrange"); (Array, "mk_array") ] - | HOL4 -> [ (Range, "mk_range"); (Array, "mk_array") ] - -let assumed_variants () : (assumed_ty * VariantId.id * string) list = - match !backend with - | FStar -> - [ - (Result, result_return_id, "Return"); - (Result, result_fail_id, "Fail"); - (Error, error_failure_id, "Failure"); - (Error, error_out_of_fuel_id, "OutOfFuel"); - (* No Fuel::Zero on purpose *) - (* No Fuel::Succ on purpose *) - (Option, option_some_id, "Some"); - (Option, option_none_id, "None"); - ] - | Coq -> - [ - (Result, result_return_id, "Return"); - (Result, result_fail_id, "Fail_"); - (Error, error_failure_id, "Failure"); - (Error, error_out_of_fuel_id, "OutOfFuel"); - (Fuel, fuel_zero_id, "O"); - (Fuel, fuel_succ_id, "S"); - (Option, option_some_id, "Some"); - (Option, option_none_id, "None"); - ] - | Lean -> - [ - (Result, result_return_id, "ret"); - (Result, result_fail_id, "fail"); - (Error, error_failure_id, "panic"); - (* No Fuel::Zero on purpose *) - (* No Fuel::Succ on purpose *) - (Option, option_some_id, "some"); - (Option, option_none_id, "none"); - ] - | HOL4 -> - [ - (Result, result_return_id, "Return"); - (Result, result_fail_id, "Fail"); - (Error, error_failure_id, "Failure"); - (* No Fuel::Zero on purpose *) - (* No Fuel::Succ on purpose *) - (Option, option_some_id, "SOME"); - (Option, option_none_id, "NONE"); - ] - -let assumed_llbc_functions () : - (A.assumed_fun_id * T.RegionGroupId.id option * string) list = - let rg0 = Some T.RegionGroupId.zero in - match !backend with - | FStar | Coq | HOL4 -> - [ - (Replace, None, "mem_replace_fwd"); - (Replace, rg0, "mem_replace_back"); - (VecNew, None, "vec_new"); - (VecPush, None, "vec_push_fwd") (* Shouldn't be used *); - (VecPush, rg0, "vec_push_back"); - (VecInsert, None, "vec_insert_fwd") (* Shouldn't be used *); - (VecInsert, rg0, "vec_insert_back"); - (VecLen, None, "vec_len"); - (VecIndex, None, "vec_index_fwd"); - (VecIndex, rg0, "vec_index_back") (* shouldn't be used *); - (VecIndexMut, None, "vec_index_mut_fwd"); - (VecIndexMut, rg0, "vec_index_mut_back"); - (ArrayIndexShared, None, "array_index_shared"); - (ArrayIndexMut, None, "array_index_mut_fwd"); - (ArrayIndexMut, rg0, "array_index_mut_back"); - (ArrayToSliceShared, None, "array_to_slice_shared"); - (ArrayToSliceMut, None, "array_to_slice_mut_fwd"); - (ArrayToSliceMut, rg0, "array_to_slice_mut_back"); - (ArraySubsliceShared, None, "array_subslice_shared"); - (ArraySubsliceMut, None, "array_subslice_mut_fwd"); - (ArraySubsliceMut, rg0, "array_subslice_mut_back"); - (SliceIndexShared, None, "slice_index_shared"); - (SliceIndexMut, None, "slice_index_mut_fwd"); - (SliceIndexMut, rg0, "slice_index_mut_back"); - (SliceSubsliceShared, None, "slice_subslice_shared"); - (SliceSubsliceMut, None, "slice_subslice_mut_fwd"); - (SliceSubsliceMut, rg0, "slice_subslice_mut_back"); - (SliceLen, None, "slice_len"); - ] - | Lean -> - [ - (Replace, None, "mem.replace"); - (Replace, rg0, "mem.replace_back"); - (VecNew, None, "Vec.new"); - (VecPush, None, "Vec.push_fwd") (* Shouldn't be used *); - (VecPush, rg0, "Vec.push"); - (VecInsert, None, "Vec.insert_fwd") (* Shouldn't be used *); - (VecInsert, rg0, "Vec.insert"); - (VecLen, None, "Vec.len"); - (VecIndex, None, "Vec.index_shared"); - (VecIndex, rg0, "Vec.index_shared_back") (* shouldn't be used *); - (VecIndexMut, None, "Vec.index_mut"); - (VecIndexMut, rg0, "Vec.index_mut_back"); - (ArrayIndexShared, None, "Array.index_shared"); - (ArrayIndexMut, None, "Array.index_mut"); - (ArrayIndexMut, rg0, "Array.index_mut_back"); - (ArrayToSliceShared, None, "Array.to_slice_shared"); - (ArrayToSliceMut, None, "Array.to_slice_mut"); - (ArrayToSliceMut, rg0, "Array.to_slice_mut_back"); - (ArraySubsliceShared, None, "Array.subslice_shared"); - (ArraySubsliceMut, None, "Array.subslice_mut"); - (ArraySubsliceMut, rg0, "Array.subslice_mut_back"); - (SliceIndexShared, None, "Slice.index_shared"); - (SliceIndexMut, None, "Slice.index_mut"); - (SliceIndexMut, rg0, "Slice.index_mut_back"); - (SliceSubsliceShared, None, "Slice.subslice_shared"); - (SliceSubsliceMut, None, "Slice.subslice_mut"); - (SliceSubsliceMut, rg0, "Slice.subslice_mut_back"); - (SliceLen, None, "Slice.len"); - ] - -let assumed_pure_functions () : (pure_assumed_fun_id * string) list = - match !backend with - | FStar -> - [ - (Return, "return"); - (Fail, "fail"); - (Assert, "massert"); - (FuelDecrease, "decrease"); - (FuelEqZero, "is_zero"); - ] - | Coq -> - (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) - [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ] - | Lean -> - (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) - [ (Return, "return"); (Fail, "fail_"); (Assert, "massert") ] - | HOL4 -> - (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) - [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ] - -let names_map_init () : names_map_init = - { - keywords = keywords (); - assumed_adts = assumed_adts (); - assumed_structs = assumed_struct_constructors (); - assumed_variants = assumed_variants (); - assumed_llbc_functions = assumed_llbc_functions (); - assumed_pure_functions = assumed_pure_functions (); - } - -let extract_unop (extract_expr : bool -> texpression -> unit) - (fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit - = - match unop with - | Not | Neg _ -> - let unop = unop_name unop in - if inside then F.pp_print_string fmt "("; - F.pp_print_string fmt unop; - F.pp_print_space fmt (); - extract_expr true arg; - if inside then F.pp_print_string fmt ")" - | Cast (src, tgt) -> ( - (* HOL4 has a special treatment: because it doesn't support dependent - types, we don't have a specific operator for the cast *) - match !backend with - | HOL4 -> - (* Casting, say, an u32 to an i32 would be done as follows: - {[ - mk_i32 (u32_to_int x) - ]} - *) - if inside then F.pp_print_string fmt "("; - F.pp_print_string fmt ("mk_" ^ int_name tgt); - F.pp_print_space fmt (); - F.pp_print_string fmt "("; - F.pp_print_string fmt (int_name src ^ "_to_int"); - F.pp_print_space fmt (); - extract_expr true arg; - F.pp_print_string fmt ")"; - if inside then F.pp_print_string fmt ")" - | FStar | Coq | Lean -> - (* Rem.: the source type is an implicit parameter *) - if inside then F.pp_print_string fmt "("; - let cast_str = - match !backend with - | Coq | FStar -> "scalar_cast" - | Lean -> (* TODO: I8.cast, I16.cast, etc.*) "Scalar.cast" - | HOL4 -> raise (Failure "Unreachable") - in - F.pp_print_string fmt cast_str; - F.pp_print_space fmt (); - if !backend <> Lean then ( - F.pp_print_string fmt - (StringUtils.capitalize_first_letter - (PrintPure.integer_type_to_string src)); - F.pp_print_space fmt ()); - if !backend = Lean then F.pp_print_string fmt ("." ^ int_name tgt) - else - F.pp_print_string fmt - (StringUtils.capitalize_first_letter - (PrintPure.integer_type_to_string tgt)); - F.pp_print_space fmt (); - extract_expr true arg; - if inside then F.pp_print_string fmt ")") - -(** [extract_expr] : the boolean argument is [inside] *) -let extract_binop (extract_expr : bool -> texpression -> unit) - (fmt : F.formatter) (inside : bool) (binop : E.binop) - (int_ty : integer_type) (arg0 : texpression) (arg1 : texpression) : unit = - if inside then F.pp_print_string fmt "("; - (* Some binary operations have a special notation depending on the backend *) - (match (!backend, binop) with - | HOL4, (Eq | Ne) - | (FStar | Coq | Lean), (Eq | Lt | Le | Ne | Ge | Gt) - | Lean, (Div | Rem | Add | Sub | Mul) -> - let binop = - match binop with - | Eq -> "=" - | Lt -> "<" - | Le -> "<=" - | Ne -> if !backend = Lean then "!=" else "<>" - | Ge -> ">=" - | Gt -> ">" - | Div -> "/" - | Rem -> "%" - | Add -> "+" - | Sub -> "-" - | Mul -> "*" - | _ -> raise (Failure "Unreachable") - in - let binop = - match !backend with FStar | Lean | HOL4 -> binop | Coq -> "s" ^ binop - in - extract_expr false arg0; - F.pp_print_space fmt (); - F.pp_print_string fmt binop; - F.pp_print_space fmt (); - extract_expr false arg1 - | _, (Lt | Le | Ge | Gt | Div | Rem | Add | Sub | Mul) -> - let binop = named_binop_name binop int_ty in - F.pp_print_string fmt binop; - F.pp_print_space fmt (); - extract_expr true arg0; - F.pp_print_space fmt (); - extract_expr true arg1 - | _, (BitXor | BitAnd | BitOr | Shl | Shr) -> raise Unimplemented); - if inside then F.pp_print_string fmt ")" - -let type_decl_kind_to_qualif (kind : decl_kind) - (type_kind : type_decl_kind option) : string option = - match !backend with - | FStar -> ( - match kind with - | SingleNonRec -> Some "type" - | SingleRec -> Some "type" - | MutRecFirst -> Some "type" - | MutRecInner -> Some "and" - | MutRecLast -> Some "and" - | Assumed -> Some "assume type" - | Declared -> Some "val") - | Coq -> ( - match (kind, type_kind) with - | SingleNonRec, Some Enum -> Some "Inductive" - | SingleNonRec, Some Struct -> Some "Record" - | (SingleRec | MutRecFirst), Some _ -> Some "Inductive" - | (MutRecInner | MutRecLast), Some _ -> - (* Coq doesn't support groups of mutually recursive definitions which mix - * records and inducties: we convert everything to records if this happens - *) - Some "with" - | (Assumed | Declared), None -> Some "Axiom" - | _ -> raise (Failure "Unexpected")) - | Lean -> ( - match kind with - | SingleNonRec -> - if type_kind = Some Struct then Some "structure" else Some "inductive" - | SingleRec -> Some "inductive" - | MutRecFirst -> Some "inductive" - | MutRecInner -> Some "inductive" - | MutRecLast -> Some "inductive" - | Assumed -> Some "axiom" - | Declared -> Some "axiom") - | HOL4 -> None - -let fun_decl_kind_to_qualif (kind : decl_kind) : string option = - match !backend with - | FStar -> ( - match kind with - | SingleNonRec -> Some "let" - | SingleRec -> Some "let rec" - | MutRecFirst -> Some "let rec" - | MutRecInner -> Some "and" - | MutRecLast -> Some "and" - | Assumed -> Some "assume val" - | Declared -> Some "val") - | Coq -> ( - match kind with - | SingleNonRec -> Some "Definition" - | SingleRec -> Some "Fixpoint" - | MutRecFirst -> Some "Fixpoint" - | MutRecInner -> Some "with" - | MutRecLast -> Some "with" - | Assumed -> Some "Axiom" - | Declared -> Some "Axiom") - | Lean -> ( - match kind with - | SingleNonRec -> Some "def" - | SingleRec -> Some "divergent def" - | MutRecFirst -> Some "mutual divergent def" - | MutRecInner -> Some "divergent def" - | MutRecLast -> Some "divergent def" - | Assumed -> Some "axiom" - | Declared -> Some "axiom") - | HOL4 -> None - -(** The type of types. - - TODO: move inside the formatter? - *) -let type_keyword () = - match !backend with - | FStar -> "Type0" - | Coq | Lean -> "Type" - | HOL4 -> raise (Failure "Unexpected") - -(** - [ctx]: we use the context to lookup type definitions, to retrieve type names. - This is used to compute variable names, when they have no basenames: in this - case we use the first letter of the type name. - - [variant_concatenate_type_name]: if true, add the type name as a prefix - to the variant names. - Ex.: - In Rust: - {[ - enum List = { - Cons(u32, Box<List>),x - Nil, - } - ]} - - F*, if option activated: - {[ - type list = - | ListCons : u32 -> list -> list - | ListNil : list - ]} - - F*, if option not activated: - {[ - type list = - | Cons : u32 -> list -> list - | Nil : list - ]} - - Rk.: this should be true by default, because in Rust all the variant names - are actively uniquely identifier by the type name [List::Cons(...)], while - in other languages it is not necessarily the case, and thus clashes can mess - up type checking. Note that some languages actually forbids the name clashes - (it is the case of F* ). - *) -let mk_formatter (ctx : trans_ctx) (crate_name : string) - (variant_concatenate_type_name : bool) : formatter = - let int_name = int_name in - - (* Prepare a name. - * The first id elem is always the crate: if it is the local crate, - * we remove it. - * We also remove all the disambiguators, then convert everything to strings. - * **Rmk:** because we remove the disambiguators, there may be name collisions - * (which is ok, because we check for name collisions and fail if there is any). - *) - let get_name (name : name) : string list = - (* Rmk.: initially we only filtered the disambiguators equal to 0 *) - let name = Names.filter_disambiguators name in - match name with - | Ident crate :: name -> - let name = if crate = crate_name then name else Ident crate :: name in - let name = - List.map - (function - | Names.Ident s -> s - | Disambiguator d -> Names.Disambiguator.to_string d) - name - in - name - | _ -> - raise (Failure ("Unexpected name shape: " ^ Print.name_to_string name)) - in - let get_type_name = get_name in - let type_name_to_camel_case name = - let name = get_type_name name in - let name = List.map to_camel_case name in - String.concat "" name - in - let type_name_to_snake_case name = - let name = get_type_name name in - let name = List.map to_snake_case name in - let name = String.concat "_" name in - match !backend with - | FStar | Lean | HOL4 -> name - | Coq -> capitalize_first_letter name - in - let type_name name = - match !backend with - | FStar | Coq | HOL4 -> type_name_to_snake_case name ^ "_t" - | Lean -> String.concat "." (get_type_name name) - in - let field_name (def_name : name) (field_id : FieldId.id) - (field_name : string option) : string = - let field_name = - match field_name with - | Some field_name -> field_name - | None -> FieldId.to_string field_id - in - if !Config.record_fields_short_names then field_name - else - let def_name = type_name_to_snake_case def_name ^ "_" in - def_name ^ field_name - in - let variant_name (def_name : name) (variant : string) : string = - match !backend with - | FStar | Coq | HOL4 -> - let variant = to_camel_case variant in - if variant_concatenate_type_name then - type_name_to_camel_case def_name ^ variant - else variant - | Lean -> variant - in - let struct_constructor (basename : name) : string = - let tname = type_name basename in - let prefix = - match !backend with FStar -> "Mk" | Coq | HOL4 -> "mk" | Lean -> "" - in - let suffix = - match !backend with FStar | Coq | HOL4 -> "" | Lean -> ".mk" - in - prefix ^ tname ^ suffix - in - let get_fun_name fname = - let fname = get_name fname in - (* TODO: don't convert to snake case for Coq, HOL4, F* *) - match !backend with - | FStar | Coq | HOL4 -> String.concat "_" (List.map to_snake_case fname) - | Lean -> String.concat "." fname - in - let global_name (name : global_name) : string = - (* Converting to snake case also lowercases the letters (in Rust, global - * names are written in capital letters). *) - let parts = List.map to_snake_case (get_name name) in - String.concat "_" parts - in - let fun_name (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option) - (num_rgs : int) (rg : region_group_info option) (filter_info : bool * int) - : string = - let fname = get_fun_name fname in - (* Compute the suffix *) - let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in - (* Concatenate *) - fname ^ suffix - in - - let termination_measure_name (_fid : A.FunDeclId.id) (fname : fun_name) - (num_loops : int) (loop_id : LoopId.id option) : string = - let fname = get_fun_name fname in - let lp_suffix = default_fun_loop_suffix num_loops loop_id in - (* Compute the suffix *) - let suffix = - match !Config.backend with - | FStar -> "_decreases" - | Lean -> "_terminates" - | Coq | HOL4 -> raise (Failure "Unexpected") - in - (* Concatenate *) - fname ^ lp_suffix ^ suffix - in - - let decreases_proof_name (_fid : A.FunDeclId.id) (fname : fun_name) - (num_loops : int) (loop_id : LoopId.id option) : string = - let fname = get_fun_name fname in - let lp_suffix = default_fun_loop_suffix num_loops loop_id in - (* Compute the suffix *) - let suffix = - match !Config.backend with - | Lean -> "_decreases" - | FStar | Coq | HOL4 -> raise (Failure "Unexpected") - in - (* Concatenate *) - fname ^ lp_suffix ^ suffix - in - - let opaque_pre () = - match !Config.backend with - | FStar | Coq | HOL4 -> "" - | Lean -> if !Config.wrap_opaque_in_sig then "opaque_defs." else "" - in - - let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) - : string = - (* If there is a basename, we use it *) - match basename with - | Some basename -> - (* This should be a no-op *) - to_snake_case basename - | None -> ( - (* No basename: we use the first letter of the type *) - match ty with - | Adt (type_id, tys, _) -> ( - match type_id with - | Tuple -> - (* The "pair" case is frequent enough to have its special treatment *) - if List.length tys = 2 then "p" else "t" - | Assumed Result -> "r" - | Assumed Error -> ConstStrings.error_basename - | Assumed Fuel -> ConstStrings.fuel_basename - | Assumed Option -> "opt" - | Assumed Vec -> "v" - | Assumed Array -> "a" - | Assumed Slice -> "s" - | Assumed Str -> "s" - | Assumed Range -> "r" - | Assumed State -> ConstStrings.state_basename - | AdtId adt_id -> - let def = - TypeDeclId.Map.find adt_id ctx.type_context.type_decls - in - (* We do the following: - * - compute the type name, and retrieve the last ident - * - convert this to snake case - * - take the first letter of every "letter group" - * Ex.: ["hashmap"; "HashMap"] ~~> "HashMap" -> "hash_map" -> "hm" - *) - (* Thename shouldn't be empty, and its last element should - * be an ident *) - let cl = List.nth def.name (List.length def.name - 1) in - let cl = to_snake_case (Names.as_ident cl) in - let cl = String.split_on_char '_' cl in - let cl = List.filter (fun s -> String.length s > 0) cl in - assert (List.length cl > 0); - let cl = List.map (fun s -> s.[0]) cl in - StringUtils.string_of_chars cl) - | TypeVar _ -> ( - (* TODO: use "t" also for F* *) - match !backend with - | FStar -> "x" (* lacking inspiration here... *) - | Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *)) - | Literal lty -> ( - match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i") - | Arrow _ -> "f") - in - let type_var_basename (_varset : StringSet.t) (basename : string) : string = - (* Rust type variables are snake-case and start with a capital letter *) - match !backend with - | FStar -> - (* This is *not* a no-op: this removes the capital letter *) - to_snake_case basename - | HOL4 -> - (* In HOL4, type variable names must start with "'" *) - "'" ^ to_snake_case basename - | Coq | Lean -> basename - in - let const_generic_var_basename (_varset : StringSet.t) (basename : string) : - string = - (* Rust type variables are snake-case and start with a capital letter *) - match !backend with - | FStar | HOL4 -> - (* This is *not* a no-op: this removes the capital letter *) - to_snake_case basename - | Coq | Lean -> basename - in - let append_index (basename : string) (i : int) : string = - basename ^ string_of_int i - in - - let extract_literal (fmt : F.formatter) (inside : bool) (cv : literal) : unit - = - match cv with - | Scalar sv -> ( - match !backend with - | FStar -> F.pp_print_string fmt (Z.to_string sv.PV.value) - | Coq | HOL4 -> - let print_brackets = inside && !backend = HOL4 in - if print_brackets then F.pp_print_string fmt "("; - (match !backend with - | Coq -> () - | HOL4 -> - F.pp_print_string fmt ("int_to_" ^ int_name sv.PV.int_ty); - F.pp_print_space fmt () - | _ -> raise (Failure "Unreachable")); - (* We need to add parentheses if the value is negative *) - if sv.PV.value >= Z.of_int 0 then - F.pp_print_string fmt (Z.to_string sv.PV.value) - else F.pp_print_string fmt ("(" ^ Z.to_string sv.PV.value ^ ")"); - (match !backend with - | Coq -> F.pp_print_string fmt ("%" ^ int_name sv.PV.int_ty) - | HOL4 -> () - | _ -> raise (Failure "Unreachable")); - if print_brackets then F.pp_print_string fmt ")" - | Lean -> - F.pp_print_string fmt "("; - F.pp_print_string fmt (int_name sv.int_ty); - F.pp_print_string fmt ".ofInt "; - (* Something very annoying: negated values like `-3` are - ambiguous in Lean because of conversions, so we have to - be extremely explicit with negative numbers. - *) - if Z.lt sv.value Z.zero then ( - F.pp_print_string fmt "("; - F.pp_print_string fmt "-"; - F.pp_print_string fmt "("; - Z.pp_print fmt (Z.neg sv.value); - F.pp_print_string fmt ":Int"; - F.pp_print_string fmt ")"; - F.pp_print_string fmt ")") - else Z.pp_print fmt sv.value; - F.pp_print_string fmt ")") - | Bool b -> - let b = - match !backend with - | HOL4 -> if b then "T" else "F" - | Coq | FStar | Lean -> if b then "true" else "false" - in - F.pp_print_string fmt b - | Char c -> ( - match !backend with - | HOL4 -> - (* [#"a"] is a notation for [CHR 97] (97 is the ASCII code for 'a') *) - F.pp_print_string fmt ("#\"" ^ String.make 1 c ^ "\"") - | FStar | Lean -> F.pp_print_string fmt ("'" ^ String.make 1 c ^ "'") - | Coq -> - if inside then F.pp_print_string fmt "("; - F.pp_print_string fmt "char_of_byte"; - F.pp_print_space fmt (); - (* Convert the the char to ascii *) - let c = - let i = Char.code c in - let x0 = i / 16 in - let x1 = i mod 16 in - "Coq.Init.Byte.x" ^ string_of_int x0 ^ string_of_int x1 - in - F.pp_print_string fmt c; - if inside then F.pp_print_string fmt ")") - in - let bool_name = if !backend = Lean then "Bool" else "bool" in - let char_name = if !backend = Lean then "Char" else "char" in - let str_name = if !backend = Lean then "String" else "string" in - { - bool_name; - char_name; - int_name; - str_name; - type_decl_kind_to_qualif; - fun_decl_kind_to_qualif; - field_name; - variant_name; - struct_constructor; - type_name; - global_name; - fun_name; - termination_measure_name; - decreases_proof_name; - opaque_pre; - var_basename; - type_var_basename; - const_generic_var_basename; - append_index; - extract_literal; - extract_unop; - extract_binop; - } - -let mk_formatter_and_names_map (ctx : trans_ctx) (crate_name : string) - (variant_concatenate_type_name : bool) : formatter * names_map = - let fmt = mk_formatter ctx crate_name variant_concatenate_type_name in - let names_map = initialize_names_map fmt (names_map_init ()) in - (fmt, names_map) - -let is_single_opaque_fun_decl_group (dg : Pure.fun_decl list) : bool = - match dg with [ d ] -> d.body = None | _ -> false - -let is_single_opaque_type_decl_group (dg : Pure.type_decl list) : bool = - match dg with [ d ] -> d.kind = Opaque | _ -> false - -let is_empty_record_type_decl (d : Pure.type_decl) : bool = d.kind = Struct [] - -let is_empty_record_type_decl_group (dg : Pure.type_decl list) : bool = - match dg with [ d ] -> is_empty_record_type_decl d | _ -> false - -(** In some provers, groups of definitions must be delimited. - - - in Coq, *every* group (including singletons) must end with "." - - in Lean, groups of mutually recursive definitions must end with "end" - - in HOL4 (in most situations) the whole group must be within a `Define` command - - Calls to {!extract_fun_decl} should be inserted between calls to - {!start_fun_decl_group} and {!end_fun_decl_group}. - - TODO: maybe those [{start/end}_decl_group] functions are not that much a good - idea and we should merge them with the corresponding [extract_decl] functions. - *) -let start_fun_decl_group (ctx : extraction_ctx) (fmt : F.formatter) - (is_rec : bool) (dg : Pure.fun_decl list) = - match !backend with - | FStar | Coq | Lean -> () - | HOL4 -> - (* In HOL4, opaque functions have a special treatment *) - if is_single_opaque_fun_decl_group dg then () - else - let with_opaque_pre = false in - let compute_fun_def_name (def : Pure.fun_decl) : string = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx - ^ "_def" - in - let names = List.map compute_fun_def_name dg in - (* Add a break before *) - F.pp_print_break fmt 0 0; - (* Open the box for the delimiters *) - F.pp_open_vbox fmt 0; - (* Open the box for the definitions themselves *) - F.pp_open_vbox fmt ctx.indent_incr; - (* Print the delimiters *) - if is_rec then - F.pp_print_string fmt - ("val [" ^ String.concat ", " names ^ "] = DefineDiv ‘") - else ( - assert (List.length names = 1); - let name = List.hd names in - F.pp_print_string fmt ("val " ^ name ^ " = Define ‘")); - F.pp_print_cut fmt () - -(** See {!start_fun_decl_group}. *) -let end_fun_decl_group (fmt : F.formatter) (is_rec : bool) - (dg : Pure.fun_decl list) = - match !backend with - | FStar -> () - | Coq -> - (* For aesthetic reasons, we print the Coq end group delimiter directly - in {!extract_fun_decl}. *) - () - | Lean -> - (* We must add the "end" keyword to groups of mutually recursive functions *) - if is_rec && List.length dg > 1 then ( - F.pp_print_cut fmt (); - F.pp_print_string fmt "end"; - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0) - else () - | HOL4 -> - (* In HOL4, opaque functions have a special treatment *) - if is_single_opaque_fun_decl_group dg then () - else ( - (* Close the box for the definitions *) - F.pp_close_box fmt (); - (* Print the end delimiter *) - F.pp_print_cut fmt (); - F.pp_print_string fmt "’"; - (* Close the box for the delimiters *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0) - -(** See {!start_fun_decl_group}: similar usage, but for the type declarations. *) -let start_type_decl_group (ctx : extraction_ctx) (fmt : F.formatter) - (is_rec : bool) (dg : Pure.type_decl list) = - match !backend with - | FStar | Coq -> () - | Lean -> - if is_rec && List.length dg > 1 then ( - F.pp_print_space fmt (); - F.pp_print_string fmt "mutual"; - F.pp_print_space fmt ()) - | HOL4 -> - (* In HOL4, opaque types and empty records have a special treatment *) - if - is_single_opaque_type_decl_group dg - || is_empty_record_type_decl_group dg - then () - else ( - (* Add a break before *) - F.pp_print_break fmt 0 0; - (* Open the box for the delimiters *) - F.pp_open_vbox fmt 0; - (* Open the box for the definitions themselves *) - F.pp_open_vbox fmt ctx.indent_incr; - (* Print the delimiters *) - F.pp_print_string fmt "Datatype:"; - F.pp_print_cut fmt ()) - -(** See {!start_fun_decl_group}. *) -let end_type_decl_group (fmt : F.formatter) (is_rec : bool) - (dg : Pure.type_decl list) = - match !backend with - | FStar -> () - | Coq -> - (* For aesthetic reasons, we print the Coq end group delimiter directly - in {!extract_fun_decl}. *) - () - | Lean -> - (* We must add the "end" keyword to groups of mutually recursive functions *) - if is_rec && List.length dg > 1 then ( - F.pp_print_cut fmt (); - F.pp_print_string fmt "end"; - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0) - else () - | HOL4 -> - (* In HOL4, opaque types and empty records have a special treatment *) - if - is_single_opaque_type_decl_group dg - || is_empty_record_type_decl_group dg - then () - else ( - (* Close the box for the definitions *) - F.pp_close_box fmt (); - (* Print the end delimiter *) - F.pp_print_cut fmt (); - F.pp_print_string fmt "End"; - (* Close the box for the delimiters *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0) - -let unit_name () = - match !backend with Lean -> "Unit" | Coq | FStar | HOL4 -> "unit" - -(** Small helper *) -let extract_arrow (fmt : F.formatter) () : unit = - if !Config.backend = Lean then F.pp_print_string fmt "→" - else F.pp_print_string fmt "->" - -let extract_const_generic (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (cg : const_generic) : unit = - match cg with - | ConstGenericGlobal id -> - let s = ctx_get_global ctx.use_opaque_pre id ctx in - F.pp_print_string fmt s - | ConstGenericValue v -> ctx.fmt.extract_literal fmt inside v - | ConstGenericVar id -> - let s = ctx_get_const_generic_var id ctx in - F.pp_print_string fmt s - -let extract_literal_type (ctx : extraction_ctx) (fmt : F.formatter) - (ty : literal_type) : unit = - match ty with - | Bool -> F.pp_print_string fmt ctx.fmt.bool_name - | Char -> F.pp_print_string fmt ctx.fmt.char_name - | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty) - -(** [inside] constrols whether we should add parentheses or not around type - applications (if [true] we add parentheses). - - [no_params_tys]: for all the types inside this set, do not print the type parameters. - This is used for HOL4. As polymorphism is uniform in HOL4, printing the - type parameters in the recursive definitions is useless (and actually - forbidden). - - For instance, where in F* we would write: - {[ - type list a = | Nil : list a | Cons : a -> list a -> list a - ]} - - In HOL4 we would simply write: - {[ - Datatype: - list = Nil 'a | Cons 'a list - End - ]} - *) -let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) - (no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit = - let extract_rec = extract_ty ctx fmt no_params_tys in - match ty with - | Adt (type_id, tys, cgs) -> ( - let has_params = tys <> [] || cgs <> [] in - match type_id with - | Tuple -> - (* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type: - * we have to write [unit]... *) - if tys = [] then F.pp_print_string fmt (unit_name ()) - else ( - F.pp_print_string fmt "("; - Collections.List.iter_link - (fun () -> - F.pp_print_space fmt (); - let product = - match !backend with - | FStar -> "&" - | Coq -> "*" - | Lean -> "×" - | HOL4 -> "#" - in - F.pp_print_string fmt product; - F.pp_print_space fmt ()) - (extract_rec true) tys; - F.pp_print_string fmt ")") - | AdtId _ | Assumed _ -> ( - (* HOL4 behaves differently. Where in Coq/FStar/Lean we would write: - `tree a b` - - In HOL4 we would write: - `('a, 'b) tree` - *) - let with_opaque_pre = false in - match !backend with - | FStar | Coq | Lean -> - let print_paren = inside && has_params in - if print_paren then F.pp_print_string fmt "("; - (* TODO: for now, only the opaque *functions* are extracted in the - opaque module. The opaque *types* are assumed. *) - F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx); - if tys <> [] then ( - F.pp_print_space fmt (); - Collections.List.iter_link (F.pp_print_space fmt) - (extract_rec true) tys); - if cgs <> [] then ( - F.pp_print_space fmt (); - Collections.List.iter_link (F.pp_print_space fmt) - (extract_const_generic ctx fmt true) - cgs); - if print_paren then F.pp_print_string fmt ")" - | HOL4 -> - (* Const generics are unsupported in HOL4 *) - assert (cgs = []); - let print_tys = - match type_id with - | AdtId id -> not (TypeDeclId.Set.mem id no_params_tys) - | Assumed _ -> true - | _ -> raise (Failure "Unreachable") - in - if tys <> [] && print_tys then ( - let print_paren = List.length tys > 1 in - if print_paren then F.pp_print_string fmt "("; - Collections.List.iter_link - (fun () -> - F.pp_print_string fmt ","; - F.pp_print_space fmt ()) - (extract_rec true) tys; - if print_paren then F.pp_print_string fmt ")"; - F.pp_print_space fmt ()); - F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx))) - | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) - | Literal lty -> extract_literal_type ctx fmt lty - | Arrow (arg_ty, ret_ty) -> - if inside then F.pp_print_string fmt "("; - extract_rec false arg_ty; - F.pp_print_space fmt (); - extract_arrow fmt (); - F.pp_print_space fmt (); - extract_rec false ret_ty; - if inside then F.pp_print_string fmt ")" - -(** Compute the names for all the top-level identifiers used in a type - definition (type name, variant names, field names, etc. but not type - parameters). - - We need to do this preemptively, beforce extracting any definition, - because of recursive definitions. - *) -let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : - extraction_ctx = - (* Compute and register the type def name *) - let ctx = ctx_add_type_decl def ctx in - (* Compute and register: - * - the variant names, if this is an enumeration - * - the field names, if this is a structure - *) - let ctx = - match def.kind with - | Struct fields -> - (* Add the fields *) - let ctx = - fst - (ctx_add_fields def (FieldId.mapi (fun id f -> (id, f)) fields) ctx) - in - (* Add the constructor name *) - fst (ctx_add_struct def ctx) - | Enum variants -> - fst - (ctx_add_variants def - (VariantId.mapi (fun id v -> (id, v)) variants) - ctx) - | Opaque -> - (* Nothing to do *) - ctx - in - (* Return *) - ctx - -(** Print the variants *) -let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter) - (type_decl_group : TypeDeclId.Set.t) (type_name : string) - (type_params : string list) (cg_params : string list) (cons_name : string) - (fields : field list) : unit = - F.pp_print_space fmt (); - (* variant box *) - F.pp_open_hvbox fmt ctx.indent_incr; - (* [| Cons :] - * Note that we really don't want any break above so we print everything - * at once. *) - let opt_colon = if !backend <> HOL4 then " :" else "" in - F.pp_print_string fmt ("| " ^ cons_name ^ opt_colon); - let print_field (fid : FieldId.id) (f : field) (ctx : extraction_ctx) : - extraction_ctx = - F.pp_print_space fmt (); - (* Open the field box *) - F.pp_open_box fmt ctx.indent_incr; - (* Print the field names, if the backend accepts it. - * [ x :] - * Note that when printing fields, we register the field names as - * *variables*: they don't need to be unique at the top level. *) - let ctx = - match !backend with - | FStar -> ( - match f.field_name with - | None -> ctx - | Some field_name -> - let var_id = VarId.of_int (FieldId.to_int fid) in - let field_name = - ctx.fmt.var_basename ctx.names_map.names_set (Some field_name) - f.field_ty - in - let ctx, field_name = ctx_add_var field_name var_id ctx in - F.pp_print_string fmt (field_name ^ " :"); - F.pp_print_space fmt (); - ctx) - | Coq | Lean | HOL4 -> ctx - in - (* Print the field type *) - let inside = !backend = HOL4 in - extract_ty ctx fmt type_decl_group inside f.field_ty; - (* Print the arrow [->] *) - if !backend <> HOL4 then ( - F.pp_print_space fmt (); - extract_arrow fmt ()); - (* Close the field box *) - F.pp_close_box fmt (); - (* Return *) - ctx - in - (* Print the fields *) - let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in - let _ = - List.fold_left (fun ctx (fid, f) -> print_field fid f ctx) ctx fields - in - (* Sanity check: HOL4 doesn't support const generics *) - assert (cg_params = [] || !backend <> HOL4); - (* Print the final type *) - if !backend <> HOL4 then ( - F.pp_print_space fmt (); - F.pp_open_hovbox fmt 0; - F.pp_print_string fmt type_name; - List.iter - (fun p -> - F.pp_print_space fmt (); - F.pp_print_string fmt p) - (List.append type_params cg_params); - F.pp_close_box fmt ()); - (* Close the variant box *) - F.pp_close_box fmt () - -(* TODO: we don' need the [def_name] paramter: it can be retrieved from the context *) -let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) - (type_decl_group : TypeDeclId.Set.t) (def : type_decl) (def_name : string) - (type_params : string list) (cg_params : string list) - (variants : variant list) : unit = - (* We want to generate a definition which looks like this (taking F* as example): - {[ - type list a = | Cons : a -> list a -> list a | Nil : list a - ]} - - If there isn't enough space on one line: - {[ - type s = - | Cons : a -> list a -> list a - | Nil : list a - ]} - - And if we need to write the type of a variant on several lines: - {[ - type s = - | Cons : - a -> - list a -> - list a - | Nil : list a - ]} - - Finally, it is possible to give names to the variant fields in Rust. - In this situation, we generate a definition like this: - {[ - type s = - | Cons : hd:a -> tl:list a -> list a - | Nil : list a - ]} - - Note that we already printed: [type s =] - *) - let print_variant _variant_id (v : variant) = - (* We don't lookup the name, because it may have a prefix for the type - id (in the case of Lean) *) - let cons_name = ctx.fmt.variant_name def.name v.variant_name in - let fields = v.fields in - extract_type_decl_variant ctx fmt type_decl_group def_name type_params - cg_params cons_name fields - in - (* Print the variants *) - let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in - List.iter (fun (vid, v) -> print_variant vid v) variants - -let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) - (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) - (type_params : string list) (cg_params : string list) (fields : field list) - : unit = - (* We want to generate a definition which looks like this (taking F* as example): - {[ - type t = { x : int; y : bool; } - ]} - - If there isn't enough space on one line: - {[ - type t = - { - x : int; y : bool; - } - ]} - - And if there is even less space: - {[ - type t = - { - x : int; - y : bool; - } - ]} - - Also, in case there are no fields, we need to define the type as [unit] - ([type t = {}] doesn't work in F* ). - - Coq: - ==== - We need to define the constructor name upon defining the struct (record, in Coq). - The syntex is: - {[ - Record Foo = mkFoo { x : int; y : bool; }. - }] - - Also, Coq doesn't support groups of mutually recursive inductives and records. - This is fine, because we can then define records as inductives, and leverage - the fact that when record fields are accessed, the records are symbolically - expanded which introduces let bindings of the form: [let RecordCons ... = x in ...]. - As a consequence, we never use the record projectors (unless we reconstruct - them in the micro passes of course). - - HOL4: - ===== - Type definitions are written as follows: - {[ - Datatype: - tree = - TLeaf 'a - | TNode node ; - - node = - Node (tree list) - End - ]} - *) - (* Note that we already printed: [type t =] *) - let is_rec = decl_is_from_rec_group kind in - let _ = - if !backend = FStar && fields = [] then ( - F.pp_print_space fmt (); - F.pp_print_string fmt (unit_name ())) - else if !backend = Lean && fields = [] then () - (* If the definition is recursive, we may need to extract it as an inductive - (instead of a record). We start with the "normal" case: we extract it - as a record. *) - else if (not is_rec) || (!backend <> Coq && !backend <> Lean) then ( - if !backend <> Lean then F.pp_print_space fmt (); - (* If Coq: print the constructor name *) - (* TODO: remove superfluous test not is_rec below *) - if !backend = Coq && not is_rec then ( - let with_opaque_pre = false in - F.pp_print_string fmt - (ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx); - F.pp_print_string fmt " "); - (match !backend with - | Lean -> () - | FStar | Coq -> F.pp_print_string fmt "{" - | HOL4 -> F.pp_print_string fmt "<|"); - F.pp_print_break fmt 1 ctx.indent_incr; - (* The body itself *) - (* Open a box for the body *) - (match !backend with - | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0 - | Lean -> F.pp_open_vbox fmt 0); - (* Print the fields *) - let print_field (field_id : FieldId.id) (f : field) : unit = - let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in - (* Open a box for the field *) - F.pp_open_box fmt ctx.indent_incr; - F.pp_print_string fmt field_name; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_ty ctx fmt type_decl_group false f.field_ty; - if !backend <> Lean then F.pp_print_string fmt ";"; - (* Close the box for the field *) - F.pp_close_box fmt () - in - let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in - Collections.List.iter_link (F.pp_print_space fmt) - (fun (fid, f) -> print_field fid f) - fields; - (* Close the box for the body *) - F.pp_close_box fmt (); - match !backend with - | Lean -> () - | FStar | Coq -> - F.pp_print_space fmt (); - F.pp_print_string fmt "}" - | HOL4 -> - F.pp_print_space fmt (); - F.pp_print_string fmt "|>") - else ( - (* We extract for Coq or Lean, and we have a recursive record, or a record in - a group of mutually recursive types: we extract it as an inductive type *) - assert (is_rec && (!backend = Coq || !backend = Lean)); - let with_opaque_pre = false in - (* Small trick: in Lean we use namespaces, meaning we don't need to prefix - the constructor name with the name of the type at definition site, - i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq - we generate `inductive Foo := | mk ... *) - let cons_name = - if !backend = Lean then "mk" - else ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx - in - let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in - extract_type_decl_variant ctx fmt type_decl_group def_name type_params - cg_params cons_name fields) - in - () - -(** Extract a nestable, muti-line comment *) -let extract_comment (fmt : F.formatter) (sl : string list) : unit = - (* Delimiters, space after we break a line *) - let ld, space, rd = - match !backend with - | Coq | FStar | HOL4 -> ("(** ", 4, " *)") - | Lean -> ("/- ", 3, " -/") - in - F.pp_open_vbox fmt space; - F.pp_print_string fmt ld; - (match sl with - | [] -> () - | s :: sl -> - F.pp_print_string fmt s; - List.iter - (fun s -> - F.pp_print_space fmt (); - F.pp_print_string fmt s) - sl); - F.pp_print_string fmt rd; - F.pp_close_box fmt () - -(** Extract a type declaration. - - This function is for all type declarations and all backends **at the exception** - of opaque (assumed/declared) types format4 HOL4. - - See {!extract_type_decl}. - *) -let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) - (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) - (extract_body : bool) : unit = - (* Sanity check *) - assert (extract_body || !backend <> HOL4); - let type_kind = - if extract_body then - match def.kind with - | Struct _ -> Some Struct - | Enum _ -> Some Enum - | Opaque -> None - else None - in - (* If in Coq and the declaration is opaque, it must have the shape: - [Axiom Ident : forall (T0 ... Tn : Type) (N0 : ...) ... (Nn : ...), ... -> ... -> ...]. - - The boolean [is_opaque_coq] is used to detect this case. - *) - let is_opaque = type_kind = None in - let is_opaque_coq = !backend = Coq && is_opaque in - let use_forall = - is_opaque_coq && (def.type_params <> [] || def.const_generic_params <> []) - in - (* Retrieve the definition name *) - let with_opaque_pre = false in - let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in - (* Add the type and const generic params - note that we need those bindings only for the - * body translation (they are not top-level) *) - let ctx_body, type_params, cg_params = - ctx_add_type_const_generic_params def.type_params def.const_generic_params - ctx - in - let ty_cg_params = List.append type_params cg_params in - (* Add a break before *) - if !backend <> HOL4 || not (decl_is_first_from_group kind) then - F.pp_print_break fmt 0 0; - (* Print a comment to link the extracted type to its original rust definition *) - extract_comment fmt [ "[" ^ Print.name_to_string def.name ^ "]" ]; - F.pp_print_break fmt 0 0; - (* Open a box for the definition, so that whenever possible it gets printed on - * one line. Note however that in the case of Lean line breaks are important - * for parsing: we thus use a hovbox. *) - (match !backend with - | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0 - | Lean -> F.pp_open_vbox fmt 0); - (* Open a box for "type TYPE_NAME (TYPE_PARAMS CONST_GEN_PARAMS) =" *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* > "type TYPE_NAME" *) - let qualif = ctx.fmt.type_decl_kind_to_qualif kind type_kind in - (match qualif with - | Some qualif -> F.pp_print_string fmt (qualif ^ " " ^ def_name) - | None -> F.pp_print_string fmt def_name); - (* HOL4 doesn't support const generics *) - assert (cg_params = [] || !backend <> HOL4); - (* Print the type/const generic parameters *) - if ty_cg_params <> [] && !backend <> HOL4 then ( - if use_forall then ( - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "forall"); - (* Print the type parameters *) - if type_params <> [] then ( - F.pp_print_space fmt (); - F.pp_print_string fmt "("; - List.iter - (fun s -> - F.pp_print_string fmt s; - F.pp_print_space fmt ()) - type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt (type_keyword () ^ ")")); - (* Print the const generic parameters *) - List.iter - (fun (var : const_generic_var) -> - F.pp_print_space fmt (); - F.pp_print_string fmt "("; - let n = ctx_get_const_generic_var var.index ctx in - F.pp_print_string fmt n; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_literal_type ctx fmt var.ty; - F.pp_print_string fmt ")") - def.const_generic_params); - (* Print the "=" if we extract the body*) - if extract_body then ( - F.pp_print_space fmt (); - let eq = - match !backend with - | FStar -> "=" - | Coq -> ":=" - | Lean -> - if type_kind = Some Struct && kind = SingleNonRec then "where" - else ":=" - | HOL4 -> "=" - in - F.pp_print_string fmt eq) - else ( - (* Otherwise print ": Type", unless it is the HOL4 backend (in - which case we declare the type with `new_type`) *) - if use_forall then F.pp_print_string fmt "," - else ( - F.pp_print_space fmt (); - F.pp_print_string fmt ":"); - F.pp_print_space fmt (); - F.pp_print_string fmt (type_keyword ())); - (* Close the box for "type TYPE_NAME (TYPE_PARAMS) =" *) - F.pp_close_box fmt (); - (if extract_body then - match def.kind with - | Struct fields -> - extract_type_decl_struct_body ctx_body fmt type_decl_group kind def - type_params cg_params fields - | Enum variants -> - extract_type_decl_enum_body ctx_body fmt type_decl_group def def_name - type_params cg_params variants - | Opaque -> raise (Failure "Unreachable")); - (* Add the definition end delimiter *) - if !backend = HOL4 && decl_is_not_last_from_group kind then ( - F.pp_print_space fmt (); - F.pp_print_string fmt ";") - else if !backend = Coq && decl_is_last_from_group kind then ( - (* This is actually an end of group delimiter. For aesthetic reasons - we print it here instead of in {!end_type_decl_group}. *) - F.pp_print_cut fmt (); - F.pp_print_string fmt "."); - (* Close the box for the definition *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - if !backend <> HOL4 || decl_is_not_last_from_group kind then - F.pp_print_break fmt 0 0 - -(** Extract an opaque type declaration to HOL4. - - Remark (SH): having to treat this specific case separately is very annoying, - but I could not find a better way. - *) -let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) - (def : type_decl) : unit = - (* Retrieve the definition name *) - let with_opaque_pre = false in - let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in - (* Generic parameters are unsupported *) - assert (def.const_generic_params = []); - (* Count the number of parameters *) - let num_params = List.length def.type_params in - (* Generate the declaration *) - F.pp_print_space fmt (); - F.pp_print_string fmt - ("val _ = new_type (\"" ^ def_name ^ "\", " ^ string_of_int num_params ^ ")"); - F.pp_print_space fmt () - -(** Extract an empty record type declaration to HOL4. - - Empty records are not supported in HOL4, so we extract them as type - abbreviations to the unit type. - - Remark (SH): having to treat this specific case separately is very annoying, - but I could not find a better way. - *) -let extract_type_decl_hol4_empty_record (ctx : extraction_ctx) - (fmt : F.formatter) (def : type_decl) : unit = - (* Retrieve the definition name *) - let with_opaque_pre = false in - let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in - (* Sanity check *) - assert (def.type_params = []); - assert (def.const_generic_params = []); - (* Generate the declaration *) - F.pp_print_space fmt (); - F.pp_print_string fmt ("Type " ^ def_name ^ " = “: unit”"); - F.pp_print_space fmt () - -(** Extract a type declaration. - - Note that all the names used for extraction should already have been - registered. - - This function should be inserted between calls to {!start_type_decl_group} - and {!end_type_decl_group}. - *) -let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter) - (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) : - unit = - let extract_body = - match kind with - | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> true - | Assumed | Declared -> false - in - if extract_body then - if !backend = HOL4 && is_empty_record_type_decl def then - extract_type_decl_hol4_empty_record ctx fmt def - else extract_type_decl_gen ctx fmt type_decl_group kind def extract_body - else - match !backend with - | FStar | Coq | Lean -> - extract_type_decl_gen ctx fmt type_decl_group kind def extract_body - | HOL4 -> extract_type_decl_hol4_opaque ctx fmt def - -(** Auxiliary function. - - Generate [Arguments] instructions in Coq. - *) -let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) - (kind : decl_kind) (decl : type_decl) : unit = - assert (!backend = Coq); - (* Generating the [Arguments] instructions is useful only if there are type parameters *) - if decl.type_params = [] && decl.const_generic_params = [] then () - else - (* Add the type params - note that we need those bindings only for the - * body translation (they are not top-level) *) - let _ctx_body, type_params, cg_params = - ctx_add_type_const_generic_params decl.type_params - decl.const_generic_params ctx - in - (* Auxiliary function to extract an [Arguments Cons {T} _ _.] instruction *) - let extract_arguments_info (cons_name : string) (fields : 'a list) : unit = - (* Add a break before *) - F.pp_print_break fmt 0 0; - (* Open a box *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Small utility *) - let print_vars () = - List.iter - (fun (var : string) -> - F.pp_print_space fmt (); - F.pp_print_string fmt ("{" ^ var ^ "}")) - (List.append type_params cg_params) - in - let print_fields () = - List.iter - (fun _ -> - F.pp_print_space fmt (); - F.pp_print_string fmt "_") - fields - in - F.pp_print_break fmt 0 0; - F.pp_print_string fmt "Arguments"; - F.pp_print_space fmt (); - F.pp_print_string fmt cons_name; - print_vars (); - print_fields (); - F.pp_print_string fmt "."; - - (* Close the box *) - F.pp_close_box fmt () - in - - (* Generate the [Arguments] instruction *) - match decl.kind with - | Opaque -> () - | Struct fields -> - let adt_id = AdtId decl.def_id in - (* Generate the instruction for the record constructor *) - let with_opaque_pre = false in - let cons_name = ctx_get_struct with_opaque_pre adt_id ctx in - extract_arguments_info cons_name fields; - (* Generate the instruction for the record projectors, if there are *) - let is_rec = decl_is_from_rec_group kind in - if not is_rec then - FieldId.iteri - (fun fid _ -> - let cons_name = ctx_get_field adt_id fid ctx in - extract_arguments_info cons_name []) - fields; - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0 - | Enum variants -> - (* Generate the instructions *) - VariantId.iteri - (fun vid (v : variant) -> - let cons_name = ctx_get_variant (AdtId decl.def_id) vid ctx in - extract_arguments_info cons_name v.fields) - variants; - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0 - -(** Auxiliary function. - - Generate field projectors in Coq. - - Sometimes we extract records as inductives in Coq: when this happens we - have to define the field projectors afterwards. - *) -let extract_type_decl_record_field_projectors (ctx : extraction_ctx) - (fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit = - assert (!backend = Coq); - match decl.kind with - | Opaque | Enum _ -> () - | Struct fields -> - (* Records are extracted as inductives only if they are recursive *) - let is_rec = decl_is_from_rec_group kind in - if is_rec then - (* Add the type params *) - let ctx, type_params, cg_params = - ctx_add_type_const_generic_params decl.type_params - decl.const_generic_params ctx - in - let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in - let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in - let with_opaque_pre = false in - let def_name = ctx_get_local_type with_opaque_pre decl.def_id ctx in - let cons_name = - ctx_get_struct with_opaque_pre (AdtId decl.def_id) ctx - in - let extract_field_proj (field_id : FieldId.id) (_ : field) : unit = - F.pp_print_space fmt (); - (* Outer box for the projector definition *) - F.pp_open_hvbox fmt 0; - (* Inner box for the projector definition *) - F.pp_open_hvbox fmt ctx.indent_incr; - (* Open a box for the [Definition PROJ ... :=] *) - F.pp_open_hovbox fmt ctx.indent_incr; - F.pp_print_string fmt "Definition"; - F.pp_print_space fmt (); - let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in - F.pp_print_string fmt field_name; - F.pp_print_space fmt (); - (* Print the type parameters *) - if type_params <> [] then ( - F.pp_print_string fmt "{"; - List.iter - (fun p -> - F.pp_print_string fmt p; - F.pp_print_space fmt ()) - type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type}"; - F.pp_print_space fmt ()); - (* Print the const generic parameters *) - if cg_params <> [] then - List.iter - (fun (v : const_generic_var) -> - F.pp_print_string fmt "{"; - let n = ctx_get_const_generic_var v.index ctx in - F.pp_print_string fmt n; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_literal_type ctx fmt v.ty; - F.pp_print_string fmt "}"; - F.pp_print_space fmt ()) - decl.const_generic_params; - (* Print the record parameter *) - F.pp_print_string fmt "("; - F.pp_print_string fmt record_var; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt def_name; - List.iter - (fun p -> - F.pp_print_space fmt (); - F.pp_print_string fmt p) - type_params; - F.pp_print_string fmt ")"; - (* *) - F.pp_print_space fmt (); - F.pp_print_string fmt ":="; - (* Close the box for the [Definition PROJ ... :=] *) - F.pp_close_box fmt (); - F.pp_print_space fmt (); - (* Open a box for the whole match *) - F.pp_open_hvbox fmt 0; - (* Open a box for the [match ... with] *) - F.pp_open_hovbox fmt ctx.indent_incr; - F.pp_print_string fmt "match"; - F.pp_print_space fmt (); - F.pp_print_string fmt record_var; - F.pp_print_space fmt (); - F.pp_print_string fmt "with"; - (* Close the box for the [match ... with] *) - F.pp_close_box fmt (); - - (* Open a box for the branch *) - F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the match branch *) - F.pp_print_space fmt (); - F.pp_print_string fmt "|"; - F.pp_print_space fmt (); - F.pp_print_string fmt cons_name; - FieldId.iteri - (fun id _ -> - F.pp_print_space fmt (); - if field_id = id then F.pp_print_string fmt field_var - else F.pp_print_string fmt "_") - fields; - F.pp_print_space fmt (); - F.pp_print_string fmt "=>"; - F.pp_print_space fmt (); - F.pp_print_string fmt field_var; - (* Close the box for the branch *) - F.pp_close_box fmt (); - (* Print the [end] *) - F.pp_print_space fmt (); - F.pp_print_string fmt "end"; - (* Close the box for the whole match *) - F.pp_close_box fmt (); - (* Close the inner box projector *) - F.pp_close_box fmt (); - (* If Coq: end the definition with a "." *) - if !backend = Coq then ( - F.pp_print_cut fmt (); - F.pp_print_string fmt "."); - (* Close the outer box projector *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0 - in - - let extract_proj_notation (field_id : FieldId.id) (_ : field) : unit = - F.pp_print_space fmt (); - (* Outer box for the projector definition *) - F.pp_open_hvbox fmt 0; - (* Inner box for the projector definition *) - F.pp_open_hovbox fmt ctx.indent_incr; - let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in - F.pp_print_string fmt "Notation"; - F.pp_print_space fmt (); - let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in - F.pp_print_string fmt ("\"" ^ record_var ^ " .(" ^ field_name ^ ")\""); - F.pp_print_space fmt (); - F.pp_print_string fmt ":="; - F.pp_print_space fmt (); - F.pp_print_string fmt "("; - F.pp_print_string fmt field_name; - F.pp_print_space fmt (); - F.pp_print_string fmt record_var; - F.pp_print_string fmt ")"; - F.pp_print_space fmt (); - F.pp_print_string fmt "(at level 9)"; - (* Close the inner box projector *) - F.pp_close_box fmt (); - (* If Coq: end the definition with a "." *) - if !backend = Coq then ( - F.pp_print_cut fmt (); - F.pp_print_string fmt "."); - (* Close the outer box projector *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0 - in - - let extract_field_proj_and_notation (field_id : FieldId.id) - (field : field) : unit = - extract_field_proj field_id field; - extract_proj_notation field_id field - in - - FieldId.iteri extract_field_proj_and_notation fields - -(** Extract extra information for a type (e.g., [Arguments] instructions in Coq). - - Note that all the names used for extraction should already have been - registered. - *) -let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter) - (kind : decl_kind) (decl : type_decl) : unit = - match !backend with - | FStar | Lean | HOL4 -> () - | Coq -> - extract_type_decl_coq_arguments ctx fmt kind decl; - extract_type_decl_record_field_projectors ctx fmt kind decl - -(** Extract the state type declaration. *) -let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) - (kind : decl_kind) : unit = - (* Add a break before *) - F.pp_print_break fmt 0 0; - (* Print a comment *) - extract_comment fmt [ "The state type used in the state-error monad" ]; - F.pp_print_break fmt 0 0; - (* Open a box for the definition, so that whenever possible it gets printed on - * one line *) - F.pp_open_hvbox fmt 0; - (* Retrieve the name *) - let state_name = ctx_get_assumed_type State ctx in - (* The syntax for Lean and Coq is almost identical. *) - let print_axiom () = - let axiom = - match !backend with - | Coq -> "Axiom" - | Lean -> "axiom" - | FStar | HOL4 -> raise (Failure "Unexpected") - in - F.pp_print_string fmt axiom; - F.pp_print_space fmt (); - F.pp_print_string fmt state_name; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type"; - if !backend = Coq then F.pp_print_string fmt "." - in - (* The kind should be [Assumed] or [Declared] *) - (match kind with - | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> - raise (Failure "Unexpected") - | Assumed -> ( - match !backend with - | FStar -> - F.pp_print_string fmt "assume"; - F.pp_print_space fmt (); - F.pp_print_string fmt "type"; - F.pp_print_space fmt (); - F.pp_print_string fmt state_name; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type0" - | HOL4 -> - F.pp_print_string fmt ("val _ = new_type (\"" ^ state_name ^ "\", 0)") - | Coq | Lean -> print_axiom ()) - | Declared -> ( - match !backend with - | FStar -> - F.pp_print_string fmt "val"; - F.pp_print_space fmt (); - F.pp_print_string fmt state_name; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt "Type0" - | HOL4 -> - F.pp_print_string fmt ("val _ = new_type (\"" ^ state_name ^ "\", 0)") - | Coq | Lean -> print_axiom ())); - (* Close the box for the definition *) - F.pp_close_box fmt (); - (* Add breaks to insert new lines between definitions *) - F.pp_print_break fmt 0 0 +include ExtractTypes (** Compute the names for all the pure functions generated from a rust function (forward function and backward functions). *) -let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool) +let extract_fun_decl_register_names (ctx : extraction_ctx) (has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) : extraction_ctx = - let (fwd, loop_fwds), back_ls = def in - (* Register the decrease clauses, if necessary *) - let register_decreases ctx def = - if has_decreases_clause def then - (* Add the termination measure *) - let ctx = ctx_add_termination_measure def ctx in - (* Add the decreases proof for Lean only *) - match !Config.backend with - | Coq | FStar -> ctx - | HOL4 -> raise (Failure "Unexpected") - | Lean -> ctx_add_decreases_proof def ctx - else ctx - in - let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in - let register_fun ctx f = ctx_add_fun_decl (keep_fwd, def) f ctx in - let register_funs ctx fl = List.fold_left register_fun ctx fl in - (* Register the forward functions' names *) - let ctx = register_funs ctx (fwd :: loop_fwds) in - (* Register the backward functions' names *) - let ctx = - List.fold_left - (fun ctx (back, loop_backs) -> - let ctx = register_fun ctx back in - register_funs ctx loop_backs) - ctx back_ls - in - - (* Return *) - ctx + (* Ignore the trait methods **declarations** (rem.: we do not ignore the trait + method implementations): we do not need to refer to them directly. We will + only use their type for the fields of the records we generate for the trait + declarations *) + match def.fwd.f.kind with + | TraitMethodDecl _ -> ctx + | _ -> ( + (* Check if the function is builtin *) + let builtin = + let open ExtractBuiltin in + let funs_map = builtin_funs_map () in + let sname = name_to_simple_name def.fwd.f.basename in + SimpleNameMap.find_opt sname funs_map + in + (* Use the builtin names if necessary *) + match builtin with + | Some (filter_info, info) -> + (* Register the filtering information, if there is *) + let ctx = + match filter_info with + | Some keep -> + { + ctx with + funs_filter_type_args_map = + FunDeclId.Map.add def.fwd.f.def_id keep + ctx.funs_filter_type_args_map; + } + | _ -> ctx + in + let backs = List.map (fun f -> f.f) def.backs in + let funs = if def.keep_fwd then def.fwd.f :: backs else backs in + List.fold_left + (fun ctx (f : fun_decl) -> + let open ExtractBuiltin in + let fun_id = + (Pure.FunId (Regular f.def_id), f.loop_id, f.back_id) + in + let fun_info = + List.find_opt + (fun (x : builtin_fun_info) -> x.rg = f.back_id) + info + in + match fun_info with + | Some fun_info -> + ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx + | None -> + raise + (Failure + ("Not found: " + ^ Names.name_to_string f.basename + ^ ", " + ^ Print.option_to_string Pure.show_loop_id f.loop_id + ^ Print.option_to_string Pure.show_region_group_id + f.back_id))) + ctx funs + | None -> + let fwd = def.fwd in + let backs = def.backs in + (* Register the decrease clauses, if necessary *) + let register_decreases ctx def = + if has_decreases_clause def then + (* Add the termination measure *) + let ctx = ctx_add_termination_measure def ctx in + (* Add the decreases proof for Lean only *) + match !Config.backend with + | Coq | FStar -> ctx + | HOL4 -> raise (Failure "Unexpected") + | Lean -> ctx_add_decreases_proof def ctx + else ctx + in + let ctx = + List.fold_left register_decreases ctx (fwd.f :: fwd.loops) + in + let register_fun ctx f = ctx_add_fun_decl def f ctx in + let register_funs ctx fl = List.fold_left register_fun ctx fl in + (* Register the names of the forward functions *) + let ctx = + if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx + in + (* Register the names of the backward functions *) + List.fold_left + (fun ctx { f = back; loops = loop_backs } -> + let ctx = register_fun ctx back in + register_funs ctx loop_backs) + ctx backs) (** Simply add the global name to the context. *) let extract_global_decl_register_names (ctx : extraction_ctx) @@ -2122,11 +124,11 @@ let extract_adt_g_value (inside : bool) (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : extraction_ctx = match ty with - | Adt (Tuple, type_args, cg_args) -> + | Adt (Tuple, generics) -> (* Tuple *) (* For now, we only support fully applied tuple constructors *) - assert (List.length type_args = List.length field_values); - assert (cg_args = []); + assert (List.length generics.types = List.length field_values); + assert (generics.const_generics = [] && generics.trait_refs = []); (* This is very annoying: in Coq, we can't write [()] for the value of type [unit], we have to write [tt]. *) if !backend = Coq && field_values = [] then ( @@ -2144,7 +146,7 @@ let extract_adt_g_value in F.pp_print_string fmt ")"; ctx) - | Adt (adt_id, _, _) -> + | Adt (adt_id, _) -> (* "Regular" ADT *) (* If we are generating a pattern for a let-binding and we target Lean, @@ -2172,18 +174,14 @@ let extract_adt_g_value * [{ field0=...; ...; fieldn=...; }] in case of structures. *) let cons = - (* The ADT shouldn't be opaque *) - let with_opaque_pre = false in match variant_id with | Some vid -> ( (* In the case of Lean, we might have to add the type name as a prefix *) match (!backend, adt_id) with | Lean, Assumed _ -> - ctx_get_type with_opaque_pre adt_id ctx - ^ "." - ^ ctx_get_variant adt_id vid ctx + ctx_get_type adt_id ctx ^ "." ^ ctx_get_variant adt_id vid ctx | _ -> ctx_get_variant adt_id vid ctx) - | None -> ctx_get_struct with_opaque_pre adt_id ctx + | None -> ctx_get_struct adt_id ctx in let use_parentheses = inside && field_values <> [] in if use_parentheses then F.pp_print_string fmt "("; @@ -2202,8 +200,33 @@ let extract_adt_g_value (* Extract globals in the same way as variables *) let extract_global (ctx : extraction_ctx) (fmt : F.formatter) (id : A.GlobalDeclId.id) : unit = - let with_opaque_pre = ctx.use_opaque_pre in - F.pp_print_string fmt (ctx_get_global with_opaque_pre id ctx) + F.pp_print_string fmt (ctx_get_global id ctx) + +(* Filter the generics of a function if it is builtin *) +let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list) + (ctx : extraction_ctx) : ('a list, 'a list * string) Result.result = + match FunDeclId.Map.find_opt id ctx.funs_filter_type_args_map with + | None -> Result.Ok types + | Some filter -> + if List.length filter <> List.length types then ( + let decl = FunDeclId.Map.find id ctx.trans_funs in + let err = + "Ill-formed builtin information for function " + ^ Names.name_to_string decl.fwd.f.basename + ^ ": " + ^ string_of_int (List.length filter) + ^ " filtering arguments provided for " + ^ string_of_int (List.length types) + ^ " type arguments" + in + log#serror err; + Result.Error (types, err)) + else + let types = List.combine filter types in + let types = + List.filter_map (fun (b, ty) -> if b then Some ty else None) types + in + Result.Ok types (** [inside]: see {!extract_ty}. @@ -2218,7 +241,7 @@ let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) ctx | PatVar (v, _) -> let vname = - ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty + ctx.fmt.var_basename ctx.names_maps.names_map.names_set v.basename v.ty in let ctx, vname = ctx_add_var vname v.id ctx in F.pp_print_string fmt vname; @@ -2249,6 +272,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | Var var_id -> let var_name = ctx_get_var var_id ctx in F.pp_print_string fmt var_name + | CVar var_id -> + let var_name = ctx_get_const_generic_var var_id ctx in + F.pp_print_string fmt var_name | Const cv -> ctx.fmt.extract_literal fmt inside cv | App _ -> let app, args = destruct_apps e in @@ -2279,14 +305,26 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Top-level qualifier *) match qualif.id with | FunOrOp fun_id -> - extract_function_call ctx fmt inside fun_id qualif.type_args - qualif.const_generic_args args + extract_function_call ctx fmt inside fun_id qualif.generics args | Global global_id -> extract_global ctx fmt global_id | AdtCons adt_cons_id -> - extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args - qualif.const_generic_args args + extract_adt_cons ctx fmt inside adt_cons_id qualif.generics args | Proj proj -> - extract_field_projector ctx fmt inside app proj qualif.type_args args) + extract_field_projector ctx fmt inside app proj qualif.generics args + | TraitConst (trait_ref, generics, const_name) -> + let use_brackets = generics <> empty_generic_args in + if use_brackets then F.pp_print_string fmt "("; + extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref; + extract_generic_args ctx fmt TypeDeclId.Set.empty generics; + let name = + ctx_get_trait_const trait_ref.trait_decl_ref.trait_decl_id + const_name ctx + in + let add_brackets (s : string) = + if !backend = Coq then "(" ^ s ^ ")" else s + in + if use_brackets then F.pp_print_string fmt ")"; + F.pp_print_string fmt ("." ^ add_brackets name)) | _ -> (* "Regular" expression *) (* Open parentheses *) @@ -2309,8 +347,8 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (** Subcase of the app case: function call *) and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (fid : fun_or_op_id) (type_args : ty list) - (cg_args : const_generic list) (args : texpression list) : unit = + (inside : bool) (fid : fun_or_op_id) (generics : generic_args) + (args : texpression list) : unit = match (fid, args) with | Unop unop, [ arg ] -> (* A unop can have *at most* one argument (the result can't be a function!). @@ -2327,24 +365,124 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) if inside then F.pp_print_string fmt "("; (* Open a box for the function call *) F.pp_open_hovbox fmt ctx.indent_incr; - (* Print the function name *) - let with_opaque_pre = ctx.use_opaque_pre in - let fun_name = ctx_get_function with_opaque_pre fun_id ctx in - F.pp_print_string fmt fun_name; - (* Sanity check: HOL4 doesn't support const generics *) - assert (cg_args = [] || !backend <> HOL4); - (* Print the type parameters, if the backend is not HOL4 *) - if !backend <> HOL4 then ( - List.iter - (fun ty -> - F.pp_print_space fmt (); - extract_ty ctx fmt TypeDeclId.Set.empty true ty) - type_args; - List.iter - (fun cg -> + (* Print the function name. + + For the function name: the id is not the same depending on whether + we call a trait method and a "regular" function (remark: trait + method *implementations* are considered as regular functions here; + only calls to method of traits which are parameterized in a where + clause have a special treatment. + + Remark: the reason why trait method declarations have a special + treatment is that, as traits are extracted to records, we may + allow collisions between trait item names and some other names, + while we do not allow collisions between function names. + + # Impl trait refs: + ================== + When the trait ref refers to an impl, in + [InterpreterStatement.eval_transparent_function_call_symbolic] we + replace the call to the trait impl method to a call to the function + which implements the trait method (that is, we "forget" that we + called a trait method, and treat it as a regular function call). + + # Provided trait methods: + ========================= + Calls to provided trait methods also have a special treatment. + For now, we do not allow overriding provided trait methods (methods + for which a default implementation is provided in the trait declaration). + Whenever we translate a provided trait method, we translate it once as + a function which takes a trait ref as input. We have to handle this + case below. + + With an example, if in Rust we write: + {[ + fn Foo { + fn f(&self) -> u32; // Required + fn ret_true(&self) -> bool { true } // Provided + } + ]} + + We generate: + {[ + structure Foo (Self : Type) = { + f : Self -> result u32 + } + + let ret_true (Self : Type) (self_clause : Foo Self) (self : Self) : result bool = + true + ]} + *) + (match fun_id with + | FromLlbc + (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id, rg_id) -> + (* We have to check whether the trait method is required or provided *) + let trait_decl_id = trait_ref.trait_decl_ref.trait_decl_id in + let trait_decl = + TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls + in + let method_id = + PureUtils.trait_decl_get_method trait_decl method_name + in + + if not method_id.is_provided then ( + (* Required method *) + assert (lp_id = None); + extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref; + let fun_name = + ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id + method_name rg_id ctx + in + let add_brackets (s : string) = + if !backend = Coq then "(" ^ s ^ ")" else s + in + F.pp_print_string fmt ("." ^ add_brackets fun_name)) + else + (* Provided method: we see it as a regular function call, and use + the function name *) + let fun_id = + FromLlbc (FunId (Regular method_id.id), lp_id, rg_id) + in + let fun_name = ctx_get_function fun_id ctx in + F.pp_print_string fmt fun_name; + + (* Note that we do not need to print the generics for the trait + declaration: they are always implicit as they can be deduced + from the trait self clause. + + Print the trait ref (to instantate the self clause) *) F.pp_print_space fmt (); - extract_const_generic ctx fmt true cg) - cg_args); + extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref + | _ -> + let fun_name = ctx_get_function fun_id ctx in + F.pp_print_string fmt fun_name); + + (* Sanity check: HOL4 doesn't support const generics *) + assert (generics.const_generics = [] || !backend <> HOL4); + (* Print the generics. + + We might need to filter some of the type arguments, if the type + is builtin (for instance, we filter the global allocator type + argument for `Vec::new`). + *) + let types = + match fun_id with + | FromLlbc (FunId (Regular id), _, _) -> + fun_builtin_filter_types id generics.types ctx + | _ -> Result.Ok generics.types + in + (match types with + | Ok types -> + extract_generic_args ctx fmt TypeDeclId.Set.empty + { generics with types } + | Error (types, err) -> + extract_generic_args ctx fmt TypeDeclId.Set.empty + { generics with types }; + if !Config.fail_hard then raise (Failure err) + else + F.pp_print_string fmt + "(\"ERROR: ill-formed builtin: invalid number of filtering \ + arguments\")"); (* Print the arguments *) List.iter (fun ve -> @@ -2366,9 +504,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) (** Subcase of the app case: ADT constructor *) and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) - (adt_cons : adt_cons_id) (type_args : ty list) - (cg_args : const_generic list) (args : texpression list) : unit = - let e_ty = Adt (adt_cons.adt_id, type_args, cg_args) in + (adt_cons : adt_cons_id) (generics : generic_args) (args : texpression list) + : unit = + let e_ty = Adt (adt_cons.adt_id, generics) in let is_single_pat = false in let _ = extract_adt_g_value @@ -2382,7 +520,7 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (** Subcase of the app case: ADT field projector. *) and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (original_app : texpression) (proj : projection) - (_proj_type_params : ty list) (args : texpression list) : unit = + (_generics : generic_args) (args : texpression list) : unit = (* We isolate the first argument (if there is), in order to pretty print the * projection ([x.field] instead of [MkAdt?.field x] *) match args with @@ -2734,9 +872,7 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) let extract_as_unit = match (!backend, supd.struct_id) with | HOL4, AdtId adt_id -> - let d = - TypeDeclId.Map.find adt_id ctx.trans_ctx.type_context.type_decls - in + let d = TypeDeclId.Map.find adt_id ctx.trans_ctx.type_ctx.type_decls in d.kind = Struct [] | _ -> false in @@ -2835,17 +971,17 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) F.pp_open_hvbox fmt ctx.indent_incr; let need_paren = inside in if need_paren then F.pp_print_string fmt "("; - (* Open the box for `Array.mk T N [` *) + (* Open the box for `Array.replicate T N [` *) F.pp_open_hovbox fmt ctx.indent_incr; (* Print the array constructor *) - let cs = ctx_get_struct false (Assumed Array) ctx in + let cs = ctx_get_struct (Assumed Array) ctx in F.pp_print_string fmt cs; (* Print the parameters *) - let _, tys, cgs = ty_as_adt e_ty in - let ty = Collections.List.to_cons_nil tys in + let _, generics = ty_as_adt e_ty in + let ty = Collections.List.to_cons_nil generics.types in F.pp_print_space fmt (); extract_ty ctx fmt TypeDeclId.Set.empty true ty; - let cg = Collections.List.to_cons_nil cgs in + let cg = Collections.List.to_cons_nil generics.const_generics in F.pp_print_space fmt (); extract_const_generic ctx fmt true cg; F.pp_print_space fmt (); @@ -2872,17 +1008,15 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) F.pp_close_box fmt () | _ -> raise (Failure "Unreachable") -(** Insert a space, if necessary *) -let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = - if !space then space := false else F.pp_print_space fmt () - (** A small utility to print the parameters of a function signature. We return two contexts: - - the context augmented with bindings for the type parameters - - the context augmented with bindings for the type parameters *and* + - the context augmented with bindings for the generics + - the context augmented with bindings for the generics *and* bindings for the input values + We also return names for the type parameters, const generics, etc. + TODO: do we really need the first one? We should probably always use the second one. It comes from the fact that when we print the input values for the @@ -2890,57 +1024,40 @@ let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = patterns, not the variables). We should figure a cleaner way. *) let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) - (fmt : F.formatter) (def : fun_decl) : extraction_ctx * extraction_ctx = + (fmt : F.formatter) (def : fun_decl) : + extraction_ctx * extraction_ctx * string list = + (* First, add the associated types and constants if the function is a method + in a trait declaration. + + About the order: we want to make sure the names are reserved for + those (variable names might collide with them but it is ok, we will add + suffixes to the variables). + + TODO: micro-pass to update what happens when calling trait provided + functions. + *) + let ctx, trait_decl = + match def.kind with + | TraitMethodProvided (decl_id, _) -> + let trait_decl = T.TraitDeclId.Map.find decl_id ctx.trans_trait_decls in + let ctx, _ = ctx_add_trait_self_clause ctx in + let ctx = { ctx with is_provided_method = true } in + (ctx, Some trait_decl) + | _ -> (ctx, None) + in (* Add the type parameters - note that we need those bindings only for the * body translation (they are not top-level) *) - let ctx, type_params, cg_params = - ctx_add_type_const_generic_params def.signature.type_params - def.signature.const_generic_params ctx + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params def.signature.generics ctx in - (* Print the parameters - rem.: we should have filtered the functions - * with no input parameters *) - (* The type parameters. - - Note that in HOL4 we don't print the type parameters. - *) - if (type_params <> [] || cg_params <> []) && !backend <> HOL4 then ( - (* Open a box for the type and const generic parameters *) - F.pp_open_hovbox fmt 0; - (* The type parameters *) - if type_params <> [] then ( - insert_req_space fmt space; - F.pp_print_string fmt "("; - List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx in - F.pp_print_string fmt pname; - F.pp_print_space fmt ()) - def.signature.type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - let type_keyword = - match !backend with - | FStar -> "Type0" - | Coq | Lean -> "Type" - | HOL4 -> raise (Failure "Unreachable") - in - F.pp_print_string fmt (type_keyword ^ ")")); - (* The const generic parameters *) - if cg_params <> [] then - List.iter - (fun (p : const_generic_var) -> - let pname = ctx_get_const_generic_var p.index ctx in - insert_req_space fmt space; - F.pp_print_string fmt "("; - F.pp_print_string fmt pname; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_literal_type ctx fmt p.ty; - F.pp_print_string fmt ")") - def.signature.const_generic_params; - (* Close the box for the type parameters *) - F.pp_close_box fmt ()); + (* Print the generics *) + (* Open a box for the generics *) + F.pp_open_hovbox fmt 0; + (let space = Some space in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~space ~trait_decl + def.signature.generics type_params cg_params trait_clauses); + (* Close the box for the generics *) + F.pp_close_box fmt (); (* The input parameters - note that doing this adds bindings to the context *) let ctx_body = match def.body with @@ -2963,7 +1080,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) ctx) ctx body.inputs_lvs in - (ctx, ctx_body) + (ctx, ctx_body, List.concat [ type_params; cg_params; trait_clauses ]) (** A small utility to print the types of the input parameters in the form: [u32 -> list u32 -> ...] @@ -2982,6 +1099,11 @@ let extract_fun_input_parameters_types (ctx : extraction_ctx) in List.iter extract_param def.signature.inputs +let extract_fun_inputs_output_parameters_types (ctx : extraction_ctx) + (fmt : F.formatter) (def : fun_decl) : unit = + extract_fun_input_parameters_types ctx fmt def; + extract_ty ctx fmt TypeDeclId.Set.empty false def.signature.output + let assert_backend_supports_decreases_clauses () = match !backend with | FStar | Lean -> () @@ -3032,7 +1154,7 @@ let extract_template_fstar_decreases_clause (ctx : extraction_ctx) F.pp_print_space fmt (); (* Extract the parameters *) let space = ref true in - let _, _ = extract_fun_parameters space ctx fmt def in + let _, _, _ = extract_fun_parameters space ctx fmt def in insert_req_space fmt space; F.pp_print_string fmt ":"; (* Print the signature *) @@ -3094,7 +1216,7 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx) F.pp_print_space fmt (); (* Extract the parameters *) let space = ref true in - let _, ctx_body = extract_fun_parameters space ctx fmt def in + let _, ctx_body, _ = extract_fun_parameters space ctx fmt def in (* Print the ":=" *) F.pp_print_space fmt (); F.pp_print_string fmt ":="; @@ -3164,7 +1286,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = let { keep_fwd; num_backs } = PureUtils.RegularFunIdMap.find - (A.Regular def.def_id, def.loop_id, def.back_id) + (Pure.FunId (Regular def.def_id), def.loop_id, def.back_id) ctx.fun_name_info in let comment_pre = "[" ^ Print.fun_name_to_string def.basename ^ "]: " in @@ -3205,10 +1327,8 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit = assert (not def.is_global_decl_body); (* Retrieve the function name *) - let with_opaque_pre = false in let def_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id - ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in (* Add a break before *) if !backend <> HOL4 || not (decl_is_first_from_group kind) then @@ -3234,23 +1354,15 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) *) let is_opaque_coq = !backend = Coq && is_opaque in let use_forall = - is_opaque_coq - && (def.signature.type_params <> [] - || def.signature.const_generic_params <> []) + is_opaque_coq && def.signature.generics <> empty_generic_params in - (* Print the qualifier ("assume", etc.). - - if `wrap_opaque_in_sig`: we generate a record of assumed funcions. - TODO: this is obsolete. - *) - (if not (!Config.wrap_opaque_in_sig && (kind = Assumed || kind = Declared)) - then - let qualif = ctx.fmt.fun_decl_kind_to_qualif kind in - match qualif with - | Some qualif -> - F.pp_print_string fmt qualif; - F.pp_print_space fmt () - | None -> ()); + (* Print the qualifier ("assume", etc.). *) + let qualif = ctx.fmt.fun_decl_kind_to_qualif kind in + (match qualif with + | Some qualif -> + F.pp_print_string fmt qualif; + F.pp_print_space fmt () + | None -> ()); F.pp_print_string fmt def_name; F.pp_print_space fmt (); if use_forall then ( @@ -3262,7 +1374,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box for "(PARAMS) :" *) F.pp_open_hovbox fmt 0; let space = ref true in - let ctx, ctx_body = extract_fun_parameters space ctx fmt def in + let ctx, ctx_body, all_params = extract_fun_parameters space ctx fmt def in (* Print the return type - note that we have to be careful when * printing the input values for the decrease clause, because * it introduces bindings in the context... We thus "forget" @@ -3310,20 +1422,13 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* The name of the decrease clause *) let decr_name = ctx_get_termination_measure def.def_id def.loop_id ctx in F.pp_print_string fmt decr_name; - (* Print the type/const generic parameters - TODO: we do this many + (* Print the generic parameters - TODO: we do this many times, we should have a helper to factor it out *) List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx in + (fun (name : string) -> F.pp_print_space fmt (); - F.pp_print_string fmt pname) - def.signature.type_params; - List.iter - (fun (p : const_generic_var) -> - let pname = ctx_get_const_generic_var p.index ctx in - F.pp_print_space fmt (); - F.pp_print_string fmt pname) - def.signature.const_generic_params; + F.pp_print_string fmt name) + all_params; (* Print the input values: we have to be careful here to print * only the input values which are in common with the *forward* * function (the additional input values "given back" to the @@ -3410,19 +1515,12 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* Open the box for [DECREASES] *) F.pp_open_hovbox fmt ctx.indent_incr; F.pp_print_string fmt terminates_name; - (* Print the type/const generic params - TODO: factor out *) + (* Print the generic params - TODO: factor out *) List.iter - (fun (p : type_var) -> - let pname = ctx_get_type_var p.index ctx in + (fun (name : string) -> F.pp_print_space fmt (); - F.pp_print_string fmt pname) - def.signature.type_params; - List.iter - (fun (p : const_generic_var) -> - let pname = ctx_get_const_generic_var p.index ctx in - F.pp_print_space fmt (); - F.pp_print_string fmt pname) - def.signature.const_generic_params; + F.pp_print_string fmt name) + all_params; (* Print the variables *) List.iter (fun v -> @@ -3475,18 +1573,13 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = (* Retrieve the definition name *) - let with_opaque_pre = false in let def_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id - ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in - assert (def.signature.const_generic_params = []); + assert (def.signature.generics.const_generics = []); (* Add the type/const gen parameters - note that we need those bindings only for the generation of the type (they are not top-level) *) - let ctx, _, _ = - ctx_add_type_const_generic_params def.signature.type_params - def.signature.const_generic_params ctx - in + let ctx, _, _, _ = ctx_add_generic_params def.signature.generics ctx in (* Add breaks to insert new lines between definitions *) F.pp_print_break fmt 0 0; (* Open a box for the whole definition *) @@ -3635,8 +1728,13 @@ let extract_global_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) (* Print the type *) F.pp_open_hovbox fmt 0; extract_ty ctx fmt TypeDeclId.Set.empty false ty; + (* Close the definition *) + F.pp_print_string fmt ")"; + F.pp_close_box fmt (); + (* Close the definition box *) F.pp_close_box fmt (); - (* Close the definition boxe *) F.pp_close_box fmt () + (* Add a line *) + F.pp_print_space fmt () (** Extract a global declaration. @@ -3662,21 +1760,19 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (global : A.global_decl) (body : fun_decl) (interface : bool) : unit = assert body.is_global_decl_body; assert (Option.is_none body.back_id); - assert (List.length body.signature.inputs = 0); + assert (body.signature.inputs = []); assert (List.length body.signature.doutputs = 1); - assert (List.length body.signature.type_params = 0); - assert (List.length body.signature.const_generic_params = 0); + assert (body.signature.generics = empty_generic_params); (* Add a break then the name of the corresponding LLBC declaration *) F.pp_print_break fmt 0 0; extract_comment fmt [ "[" ^ Print.global_name_to_string global.name ^ "]" ]; F.pp_print_space fmt (); - let with_opaque_pre = false in - let decl_name = ctx_get_global with_opaque_pre global.def_id ctx in + let decl_name = ctx_get_global global.def_id ctx in let body_name = - ctx_get_function with_opaque_pre - (FromLlbc (Regular global.body_id, None, None)) + ctx_get_function + (FromLlbc (Pure.FunId (Regular global.body_id), None, None)) ctx in @@ -3713,6 +1809,807 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Add a break to insert lines between declarations *) F.pp_print_break fmt 0 0 +(** Similar to {!extract_trait_decl_register_names} *) +let extract_trait_decl_register_parent_clause_names (ctx : extraction_ctx) + (trait_decl : trait_decl) + (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : + extraction_ctx = + (* Compute the clause names *) + let clause_names = + match builtin_info with + | None -> + List.map + (fun (c : trait_clause) -> + let name = ctx.fmt.trait_parent_clause_name trait_decl c in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name trait_decl ^ name + in + (c.clause_id, name)) + trait_decl.parent_clauses + | Some info -> + List.map + (fun (c, name) -> (c.clause_id, name)) + (List.combine trait_decl.parent_clauses info.parent_clauses) + in + (* Register the names *) + List.fold_left + (fun ctx (cid, cname) -> + ctx_add (TraitParentClauseId (trait_decl.def_id, cid)) cname ctx) + ctx clause_names + +(** Similar to {!extract_trait_decl_register_names} *) +let extract_trait_decl_register_constant_names (ctx : extraction_ctx) + (trait_decl : trait_decl) + (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : + extraction_ctx = + let consts = trait_decl.consts in + (* Compute the names *) + let constant_names = + match builtin_info with + | None -> + List.map + (fun (item_name, _) -> + let name = ctx.fmt.trait_const_name trait_decl item_name in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name trait_decl ^ name + in + (item_name, name)) + consts + | Some info -> + let const_map = StringMap.of_list info.consts in + List.map + (fun (item_name, _) -> + (item_name, StringMap.find item_name const_map)) + consts + in + (* Register the names *) + List.fold_left + (fun ctx (item_name, name) -> + ctx_add (TraitItemId (trait_decl.def_id, item_name)) name ctx) + ctx constant_names + +(** Similar to {!extract_trait_decl_register_names} *) +let extract_trait_decl_type_names (ctx : extraction_ctx) + (trait_decl : trait_decl) + (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : + extraction_ctx = + let types = trait_decl.types in + (* Compute the names *) + let type_names = + match builtin_info with + | None -> + let compute_type_name (item_name : string) : string = + let type_name = ctx.fmt.trait_type_name trait_decl item_name in + if !Config.record_fields_short_names then type_name + else ctx.fmt.trait_decl_name trait_decl ^ type_name + in + let compute_clause_name (item_name : string) (clause : trait_clause) : + TraitClauseId.id * string = + let name = + ctx.fmt.trait_type_clause_name trait_decl item_name clause + in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name trait_decl ^ name + in + (clause.clause_id, name) + in + List.map + (fun (item_name, (item_clauses, _)) -> + (* Type name *) + let type_name = compute_type_name item_name in + (* Clause names *) + let clauses = + List.map (compute_clause_name item_name) item_clauses + in + (item_name, (type_name, clauses))) + types + | Some info -> + let type_map = StringMap.of_list info.types in + List.map + (fun (item_name, (item_clauses, _)) -> + let type_name, clauses_info = StringMap.find item_name type_map in + let clauses = + List.map + (fun (clause, clause_name) -> (clause.clause_id, clause_name)) + (List.combine item_clauses clauses_info) + in + (item_name, (type_name, clauses))) + types + in + (* Register the names *) + List.fold_left + (fun ctx (item_name, (type_name, clauses)) -> + let ctx = + ctx_add (TraitItemId (trait_decl.def_id, item_name)) type_name ctx + in + List.fold_left + (fun ctx (clause_id, clause_name) -> + ctx_add + (TraitItemClauseId (trait_decl.def_id, item_name, clause_id)) + clause_name ctx) + ctx clauses) + ctx type_names + +(** Similar to {!extract_trait_decl_register_names} *) +let extract_trait_decl_method_names (ctx : extraction_ctx) + (trait_decl : trait_decl) + (builtin_info : ExtractBuiltin.builtin_trait_decl_info option) : + extraction_ctx = + let required_methods = trait_decl.required_methods in + (* Compute the names *) + let method_names = + (* We add one field per required forward/backward function *) + let get_funs_for_id (id : fun_decl_id) : fun_decl list = + let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in + List.map (fun f -> f.f) (trans.fwd :: trans.backs) + in + match builtin_info with + | None -> + (* We add one field per required forward/backward function *) + let compute_item_names (item_name : string) (id : fun_decl_id) : + string * (RegionGroupId.id option * string) list = + let compute_fun_name (f : fun_decl) : RegionGroupId.id option * string + = + (* We do something special to reuse the [ctx_compute_fun_decl] + function. TODO: make it cleaner. *) + let basename : name = [ Ident item_name ] in + let f = { f with basename } in + let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in + let name = ctx_compute_fun_name trans f ctx in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx.fmt.trait_decl_name trait_decl ^ "_" ^ name + in + (f.back_id, name) + in + let funs = get_funs_for_id id in + (item_name, List.map compute_fun_name funs) + in + List.map (fun (name, id) -> compute_item_names name id) required_methods + | Some info -> + let funs_map = StringMap.of_list info.methods in + List.map + (fun (item_name, fun_id) -> + let open ExtractBuiltin in + let info = StringMap.find item_name funs_map in + let trans_funs = get_funs_for_id fun_id in + let find (trans_fun : fun_decl) = + let info = + List.find_opt + (fun (info : builtin_fun_info) -> info.rg = trans_fun.back_id) + info + in + match info with + | Some info -> (info.rg, info.extract_name) + | None -> + let err = + "Ill-formed builtin information for trait decl \"" + ^ Names.name_to_string trait_decl.name + ^ "\", method \"" ^ item_name + ^ "\": could not find name for region " + ^ Print.option_to_string Pure.show_region_group_id + trans_fun.back_id + in + log#serror err; + if !Config.fail_hard then raise (Failure err) + else (trans_fun.back_id, "%ERROR_BUILTIN_NAME_NOT_FOUND%") + in + let rg_with_name_list = List.map find trans_funs in + (item_name, rg_with_name_list)) + required_methods + in + (* Register the names *) + List.fold_left + (fun ctx (item_name, funs) -> + (* We add one field per required forward/backward function *) + List.fold_left + (fun ctx (rg, fun_name) -> + ctx_add + (TraitMethodId (trait_decl.def_id, item_name, rg)) + fun_name ctx) + ctx funs) + ctx method_names + +(** Similar to {!extract_type_decl_register_names} *) +let extract_trait_decl_register_names (ctx : extraction_ctx) + (trait_decl : trait_decl) : extraction_ctx = + (* Lookup the information if this is a builtin trait *) + let open ExtractBuiltin in + let sname = name_to_simple_name trait_decl.name in + let builtin_info = + SimpleNameMap.find_opt sname (builtin_trait_decls_map ()) + in + let ctx = + let trait_name, trait_constructor = + match builtin_info with + | None -> + ( ctx.fmt.trait_decl_name trait_decl, + ctx.fmt.trait_decl_constructor trait_decl ) + | Some info -> (info.extract_name, info.constructor) + in + let ctx = ctx_add (TraitDeclId trait_decl.def_id) trait_name ctx in + ctx_add (TraitDeclConstructorId trait_decl.def_id) trait_constructor ctx + in + (* Parent clauses *) + let ctx = + extract_trait_decl_register_parent_clause_names ctx trait_decl builtin_info + in + (* Constants *) + let ctx = + extract_trait_decl_register_constant_names ctx trait_decl builtin_info + in + (* Types *) + let ctx = extract_trait_decl_type_names ctx trait_decl builtin_info in + (* Required methods *) + let ctx = extract_trait_decl_method_names ctx trait_decl builtin_info in + ctx + +(** Similar to {!extract_type_decl_register_names} *) +let extract_trait_impl_register_names (ctx : extraction_ctx) + (trait_impl : trait_impl) : extraction_ctx = + let decl_id = trait_impl.impl_trait.trait_decl_id in + let trait_decl = TraitDeclId.Map.find decl_id ctx.trans_trait_decls in + (* Check if the trait implementation is builtin *) + let builtin_info = + let open ExtractBuiltin in + let type_sname = name_to_simple_name trait_impl.name in + let trait_sname = name_to_simple_name trait_decl.name in + SimpleNamePairMap.find_opt (type_sname, trait_sname) + (builtin_trait_impls_map ()) + in + (* Register some builtin information (if necessary) *) + let ctx, builtin_info = + match builtin_info with + | None -> (ctx, None) + | Some (filter, info) -> + let ctx = + match filter with + | None -> ctx + | Some filter -> + { + ctx with + trait_impls_filter_type_args_map = + TraitImplId.Map.add trait_impl.def_id filter + ctx.trait_impls_filter_type_args_map; + } + in + (ctx, Some info) + in + + (* For now we do not support overriding provided methods *) + assert (trait_impl.provided_methods = []); + (* Everything is taken care of by {!extract_trait_decl_register_names} *but* + the name of the implementation itself *) + (* Compute the name *) + let name = + match builtin_info with + | None -> ctx.fmt.trait_impl_name trait_decl trait_impl + | Some name -> name + in + ctx_add (TraitImplId trait_impl.def_id) name ctx + +(** Small helper. + + The type `ty` is to be understood in a very general sense. + *) +let extract_trait_item (ctx : extraction_ctx) (fmt : F.formatter) + (item_name : string) (separator : string) (ty : unit -> unit) : unit = + F.pp_print_space fmt (); + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_string fmt item_name; + F.pp_print_space fmt (); + (* ":" or "=" *) + F.pp_print_string fmt separator; + ty (); + (match !Config.backend with Lean -> () | _ -> F.pp_print_string fmt ";"); + F.pp_close_box fmt () + +let extract_trait_decl_item (ctx : extraction_ctx) (fmt : F.formatter) + (item_name : string) (ty : unit -> unit) : unit = + extract_trait_item ctx fmt item_name ":" ty + +let extract_trait_impl_item (ctx : extraction_ctx) (fmt : F.formatter) + (item_name : string) (ty : unit -> unit) : unit = + let assign = match !Config.backend with Lean | Coq -> ":=" | _ -> "=" in + extract_trait_item ctx fmt item_name assign ty + +(** Small helper - TODO: move *) +let generic_params_drop_prefix ~(drop_trait_clauses : bool) + (g1 : generic_params) (g2 : generic_params) : generic_params = + let open Collections.List in + let types = drop (length g1.types) g2.types in + let const_generics = drop (length g1.const_generics) g2.const_generics in + let trait_clauses = + if drop_trait_clauses then drop (length g1.trait_clauses) g2.trait_clauses + else g2.trait_clauses + in + { types; const_generics; trait_clauses } + +(** Small helper. + + Extract the items for a method in a trait decl. + *) +let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter) + (decl : trait_decl) (item_name : string) (id : fun_decl_id) : unit = + (* Lookup the definition *) + let trans = A.FunDeclId.Map.find id ctx.trans_funs in + (* Extract the items *) + let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in + let extract_method (f : fun_and_loops) = + let f = f.f in + let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in + let ty () = + (* Extract the generics *) + (* We need to add the generics specific to the method, by removing those + which actually apply to the trait decl *) + let generics = + let drop_trait_clauses = false in + generic_params_drop_prefix ~drop_trait_clauses decl.generics + f.signature.generics + in + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params generics ctx + in + let backend_uses_forall = + match !backend with Coq | Lean -> true | FStar | HOL4 -> false + in + let generics_not_empty = generics <> empty_generic_params in + let use_forall = generics_not_empty && backend_uses_forall in + let use_arrows = generics_not_empty && not backend_uses_forall in + let use_forall_use_sep = false in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall + ~use_forall_use_sep ~use_arrows generics type_params cg_params + trait_clauses; + if use_forall then F.pp_print_string fmt ","; + (* Extract the inputs and output *) + F.pp_print_space fmt (); + extract_fun_inputs_output_parameters_types ctx fmt f + in + extract_trait_decl_item ctx fmt fun_name ty + in + List.iter extract_method funs + +(** Extract a trait declaration *) +let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter) + (decl : trait_decl) : unit = + (* Retrieve the trait name *) + let decl_name = ctx_get_trait_decl decl.def_id ctx in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + extract_comment fmt + [ "Trait declaration: [" ^ Print.name_to_string decl.name ^ "]" ]; + F.pp_print_break fmt 0 0; + (* Open two outer boxes for the definition, so that whenever possible it gets printed on + one line and indents are correct. + + There is just an exception with Lean: in this backend, line breaks are important + for the parsing, so we always open a vertical box. + *) + if !Config.backend = Lean then F.pp_open_vbox fmt ctx.indent_incr + else ( + F.pp_open_hvbox fmt 0; + F.pp_open_hvbox fmt ctx.indent_incr); + + (* `struct Trait (....) =` *) + (* Open the box for the name + generics *) + F.pp_open_hovbox fmt ctx.indent_incr; + let qualif = + Option.get (ctx.fmt.type_decl_kind_to_qualif SingleNonRec (Some Struct)) + in + (* When checking if the trait declaration is empty: we ignore the provided + methods, because for now they are extracted separately *) + let is_empty = trait_decl_is_empty { decl with provided_methods = [] } in + if !backend = FStar && not is_empty then ( + F.pp_print_string fmt "noeq"; + F.pp_print_space fmt ()); + F.pp_print_string fmt qualif; + F.pp_print_space fmt (); + F.pp_print_string fmt decl_name; + (* Print the generics *) + let generics = decl.generics in + (* Add the type and const generic params - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params generics ctx + in + extract_generic_params ctx fmt TypeDeclId.Set.empty generics type_params + cg_params trait_clauses; + + F.pp_print_space fmt (); + if is_empty && !backend = FStar then ( + F.pp_print_string fmt "= unit"; + (* Outer box *) + F.pp_close_box fmt ()) + else if is_empty && !backend = Coq then ( + (* Coq is not very good at infering constructors *) + let cons = ctx_get_trait_constructor decl.def_id ctx in + F.pp_print_string fmt (":= " ^ cons ^ "{}."); + (* Outer box *) + F.pp_close_box fmt ()) + else ( + (match !backend with + | Lean -> F.pp_print_string fmt "where" + | FStar -> F.pp_print_string fmt "= {" + | Coq -> + let cons = ctx_get_trait_constructor decl.def_id ctx in + F.pp_print_string fmt (":= " ^ cons ^ " {") + | _ -> F.pp_print_string fmt "{"); + + (* Close the box for the name + generics *) + F.pp_close_box fmt (); + + (* + * Extract the items + *) + + (* The constants *) + List.iter + (fun (name, (ty, _)) -> + let item_name = ctx_get_trait_const decl.def_id name ctx in + let ty () = + let inside = false in + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty inside ty + in + extract_trait_decl_item ctx fmt item_name ty) + decl.consts; + + (* The types *) + List.iter + (fun (name, (clauses, _)) -> + (* Extract the type *) + let item_name = ctx_get_trait_type decl.def_id name ctx in + let ty () = + F.pp_print_space fmt (); + F.pp_print_string fmt (type_keyword ()) + in + extract_trait_decl_item ctx fmt item_name ty; + (* Extract the clauses *) + List.iter + (fun clause -> + let item_name = + ctx_get_trait_item_clause decl.def_id name clause.clause_id ctx + in + let ty () = + F.pp_print_space fmt (); + extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause + in + extract_trait_decl_item ctx fmt item_name ty) + clauses) + decl.types; + + (* The parent clauses - note that the parent clauses may refer to the types + and const generics: for this reason we extract them *after* *) + List.iter + (fun clause -> + let item_name = + ctx_get_trait_parent_clause decl.def_id clause.clause_id ctx + in + let ty () = + F.pp_print_space fmt (); + extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause + in + extract_trait_decl_item ctx fmt item_name ty) + decl.parent_clauses; + + (* The required methods *) + List.iter + (fun (name, id) -> extract_trait_decl_method_items ctx fmt decl name id) + decl.required_methods; + + (* Close the outer boxes for the definition *) + if !Config.backend <> Lean then F.pp_close_box fmt (); + (* Close the brackets *) + match !Config.backend with + | Lean -> () + | Coq -> + F.pp_print_space fmt (); + F.pp_print_string fmt "}." + | _ -> + F.pp_print_space fmt (); + F.pp_print_string fmt "}"); + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + +(** Generate the [Arguments] instructions for the trait declarationsin Coq, so + that we don't have to provide the implicit arguments when projecting the fields. *) +let extract_trait_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) + (decl : trait_decl) : unit = + (* Generating the [Arguments] instructions is useful only if there are parameters *) + let num_params = + List.length decl.generics.types + + List.length decl.generics.const_generics + + List.length decl.generics.trait_clauses + in + if num_params > 0 then ( + (* The constructor *) + let cons_name = ctx_get_trait_constructor decl.def_id ctx in + extract_coq_arguments_instruction ctx fmt cons_name num_params; + (* The constants *) + List.iter + (fun (name, _) -> + let item_name = ctx_get_trait_const decl.def_id name ctx in + extract_coq_arguments_instruction ctx fmt item_name num_params) + decl.consts; + (* The types *) + List.iter + (fun (name, (clauses, _)) -> + (* The type *) + let item_name = ctx_get_trait_type decl.def_id name ctx in + extract_coq_arguments_instruction ctx fmt item_name num_params; + (* The type clauses *) + List.iter + (fun clause -> + let item_name = + ctx_get_trait_item_clause decl.def_id name clause.clause_id ctx + in + extract_coq_arguments_instruction ctx fmt item_name num_params) + clauses) + decl.types; + (* The parent clauses *) + List.iter + (fun clause -> + let item_name = + ctx_get_trait_parent_clause decl.def_id clause.clause_id ctx + in + extract_coq_arguments_instruction ctx fmt item_name num_params) + decl.parent_clauses; + (* The required methods *) + List.iter + (fun (item_name, id) -> + (* Lookup the definition *) + let trans = A.FunDeclId.Map.find id ctx.trans_funs in + (* Extract the items *) + let funs = + if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs + in + let extract_for_method (f : fun_and_loops) = + let f = f.f in + let item_name = + ctx_get_trait_method decl.def_id item_name f.back_id ctx + in + extract_coq_arguments_instruction ctx fmt item_name num_params + in + List.iter extract_for_method funs) + decl.required_methods; + (* Add a space *) + F.pp_print_space fmt ()) + +(** See {!extract_trait_decl_coq_arguments} *) +let extract_trait_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter) + (trait_decl : trait_decl) : unit = + match !backend with + | Coq -> extract_trait_decl_coq_arguments ctx fmt trait_decl + | _ -> () + +(** Small helper. + + Extract the items for a method in a trait impl. + *) +let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter) + (impl : trait_impl) (item_name : string) (id : fun_decl_id) + (impl_generics : string list * string list * string list) : unit = + let trait_decl_id = impl.impl_trait.trait_decl_id in + (* Lookup the definition *) + let trans = A.FunDeclId.Map.find id ctx.trans_funs in + (* Extract the items *) + let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in + let extract_method (f : fun_and_loops) = + let f = f.f in + let fun_name = ctx_get_trait_method trait_decl_id item_name f.back_id ctx in + let ty () = + (* Filter the generics if the method is a builtin *) + let i_tys, _, _ = impl_generics in + let impl_types, i_tys, f_tys = + match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with + | None -> (impl.generics.types, i_tys, f.signature.generics.types) + | Some filter -> + let filter_list filter ls = + let ls = List.combine filter ls in + List.filter_map (fun (b, ty) -> if b then Some ty else None) ls + in + let impl_types = impl.generics.types in + let impl_filter = + Collections.List.prefix (List.length impl_types) filter + in + let i_tys = i_tys in + let i_filter = Collections.List.prefix (List.length i_tys) filter in + ( filter_list impl_filter impl_types, + filter_list i_filter i_tys, + filter_list filter f.signature.generics.types ) + in + let f_generics = { f.signature.generics with types = f_tys } in + (* Extract the generics - we need to quantify over the generics which + are specific to the method, and call it will all the generics + (trait impl + method generics) *) + let f_generics = + let drop_trait_clauses = true in + generic_params_drop_prefix ~drop_trait_clauses + { impl.generics with types = impl_types } + f_generics + in + (* Register and print the quantified generics *) + let ctx, f_tys, f_cgs, f_tcs = ctx_add_generic_params f_generics ctx in + let use_forall = f_generics <> empty_generic_params in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics + f_tys f_cgs f_tcs; + if use_forall then F.pp_print_string fmt ","; + (* Extract the function call *) + F.pp_print_space fmt (); + let fun_name = ctx_get_local_function f.def_id None f.back_id ctx in + F.pp_print_string fmt fun_name; + let all_generics = + let _, i_cgs, i_tcs = impl_generics in + List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ] + in + + (* Filter the generics if the function is builtin *) + List.iter + (fun p -> + F.pp_print_space fmt (); + F.pp_print_string fmt p) + all_generics + in + extract_trait_impl_item ctx fmt fun_name ty + in + List.iter extract_method funs + +(** Extract a trait implementation *) +let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) + (impl : trait_impl) : unit = + log#ldebug (lazy ("extract_trait_impl: " ^ Names.name_to_string impl.name)); + (* Retrieve the impl name *) + let impl_name = ctx_get_trait_impl impl.def_id ctx in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + extract_comment fmt + [ "Trait implementation: [" ^ Print.name_to_string impl.name ^ "]" ]; + F.pp_print_break fmt 0 0; + + (* Open two outer boxes for the definition, so that whenever possible it gets printed on + one line and indents are correct. + + There is just an exception with Lean: in this backend, line breaks are important + for the parsing, so we always open a vertical box. + *) + if !Config.backend = Lean then ( + F.pp_open_vbox fmt 0; + F.pp_open_vbox fmt ctx.indent_incr) + else ( + F.pp_open_hvbox fmt 0; + F.pp_open_hvbox fmt ctx.indent_incr); + + (* `let (....) : Trait ... =` *) + (* Open the box for the name + generics *) + F.pp_open_hovbox fmt ctx.indent_incr; + (match ctx.fmt.fun_decl_kind_to_qualif SingleNonRec with + | Some qualif -> + F.pp_print_string fmt qualif; + F.pp_print_space fmt () + | None -> ()); + F.pp_print_string fmt impl_name; + + (* Print the generics *) + (* Add the type and const generic params - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params impl.generics ctx + in + let all_generics = (type_params, cg_params, trait_clauses) in + extract_generic_params ctx fmt TypeDeclId.Set.empty impl.generics type_params + cg_params trait_clauses; + + (* Print the type *) + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_trait_decl_ref ctx fmt TypeDeclId.Set.empty false impl.impl_trait; + + (* When checking if the trait impl is empty: we ignore the provided + methods, because for now they are extracted separately *) + let is_empty = trait_impl_is_empty { impl with provided_methods = [] } in + + F.pp_print_space fmt (); + if is_empty && !Config.backend = FStar then ( + F.pp_print_string fmt "= ()"; + (* Outer box *) + F.pp_close_box fmt ()) + else if is_empty && !Config.backend = Coq then ( + (* Coq is not very good at infering constructors *) + let cons = ctx_get_trait_constructor impl.impl_trait.trait_decl_id ctx in + F.pp_print_string fmt (":= " ^ cons ^ "."); + (* Outer box *) + F.pp_close_box fmt ()) + else ( + if !Config.backend = Lean then F.pp_print_string fmt ":= {" + else if !Config.backend = Coq then F.pp_print_string fmt ":= {|" + else F.pp_print_string fmt "= {"; + + (* Close the box for the name + generics *) + F.pp_close_box fmt (); + + (* + * Extract the items + *) + let trait_decl_id = impl.impl_trait.trait_decl_id in + + (* The constants *) + List.iter + (fun (name, (_, id)) -> + let item_name = ctx_get_trait_const trait_decl_id name ctx in + let ty () = + F.pp_print_space fmt (); + F.pp_print_string fmt (ctx_get_global id ctx) + in + + extract_trait_impl_item ctx fmt item_name ty) + impl.consts; + + (* The types *) + List.iter + (fun (name, (trait_refs, ty)) -> + (* Extract the type *) + let item_name = ctx_get_trait_type trait_decl_id name ctx in + let ty () = + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty false ty + in + extract_trait_impl_item ctx fmt item_name ty; + (* Extract the clauses *) + TraitClauseId.iteri + (fun clause_id trait_ref -> + let item_name = + ctx_get_trait_item_clause trait_decl_id name clause_id ctx + in + let ty () = + F.pp_print_space fmt (); + extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref + in + extract_trait_impl_item ctx fmt item_name ty) + trait_refs) + impl.types; + + (* The parent clauses *) + TraitClauseId.iteri + (fun clause_id trait_ref -> + let item_name = + ctx_get_trait_parent_clause trait_decl_id clause_id ctx + in + let ty () = + F.pp_print_space fmt (); + extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref + in + extract_trait_impl_item ctx fmt item_name ty) + impl.parent_trait_refs; + + (* The required methods *) + List.iter + (fun (name, id) -> + extract_trait_impl_method_items ctx fmt impl name id all_generics) + impl.required_methods; + + (* Close the outer boxes for the definition, as well as the brackets *) + F.pp_close_box fmt (); + if !backend = Coq then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "|}.") + else if (not (!backend = FStar)) || not is_empty then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "}")); + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + (** Extract a unit test, if the function is a unit function (takes no parameters, returns unit). @@ -3735,8 +2632,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) (* Check if this is a unit function *) let sg = def.signature in if - sg.type_params = [] - && sg.const_generic_params = [] + sg.generics = empty_generic_params && (sg.inputs = [ mk_unit_ty ] || sg.inputs = []) && sg.output = mk_result_ty mk_unit_ty then ( @@ -3756,12 +2652,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "assert_norm"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - (* Note that if the function is opaque, the unit test will fail - because the normalizer will get stuck *) - let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -3776,12 +2668,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "Check"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - (* Note that if the function is opaque, the unit test will fail - because the normalizer will get stuck *) - let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -3793,12 +2681,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "#assert"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - (* Note that if the function is opaque, the unit test will fail - because the normalizer will get stuck *) - let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -3812,12 +2696,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) | HOL4 -> F.pp_print_string fmt "val _ = assert_return ("; F.pp_print_string fmt "“"; - (* Note that if the function is opaque, the unit test will fail - because the normalizer will get stuck *) - let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function with_opaque_pre def.def_id def.loop_id - def.back_id ctx + ctx_get_local_function def.def_id def.loop_id def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index d733c763..31b1a447 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -5,9 +5,10 @@ open TranslateCore module C = Contexts module RegionVarId = T.RegionVarId module F = Format +open ExtractBuiltin (** The local logger *) -let log = L.pure_to_extract_log +let log = L.extract_log type region_group_info = { id : RegionGroupId.id; @@ -21,8 +22,8 @@ type region_group_info = { *) } -module StringSet = Collections.MakeSet (Collections.OrderedString) -module StringMap = Collections.MakeMap (Collections.OrderedString) +module StringSet = Collections.StringSet +module StringMap = Collections.StringMap type name = Names.name type type_name = Names.type_name @@ -77,6 +78,7 @@ type decl_kind = F*: [val x : Type0] Coq: [Axiom x : Type.] *) +[@@deriving show] (** Return [true] if the declaration is the last from its group of declarations. @@ -111,9 +113,9 @@ let decl_is_first_from_group (kind : decl_kind) : bool = let decl_is_not_last_from_group (kind : decl_kind) : bool = not (decl_is_last_from_group kind) -(* TODO: this should a module we give to a functor! *) +type type_decl_kind = Enum | Struct [@@deriving show] -type type_decl_kind = Enum | Struct +(* TODO: this should be a module we give to a functor! *) (** A formatter's role is twofold: 1. Come up with name suggestions. @@ -125,6 +127,9 @@ type type_decl_kind = Enum | Struct snake case, adding prefixes/suffixes, etc. 2. Format some specific terms, like constants. + + TODO: unclear that this is useful now that all the backends are so much + entangled in Extract.ml *) type formatter = { bool_name : string; @@ -239,37 +244,14 @@ type formatter = { the same purpose as in {!field:fun_name}. - loop identifier, if this is for a loop *) - opaque_pre : unit -> string; - (** TODO: obsolete, remove. - - The prefix to use for opaque definitions. - - We need this because for some backends like Lean and Coq, we group - opaque definitions in module signatures, meaning that using those - definitions requires to prefix them with a module parameter name (such - as "opaque_defs."). - - For instance, if we have an opaque function [f : int -> int], which - is used by the non-opaque function [g], we would generate (in Coq): - {[ - (* The module signature declaring the opaque definitions *) - module type OpaqueDefs = { - f_fwd : int -> int - ... (* Other definitions *) - } - - (* The definitions generated for the non-opaque definitions *) - module Funs (opaque: OpaqueDefs) = { - let g ... = - ... - opaque_defs.f_fwd - ... - } - ]} - - Upon using [f] in [g], we don't directly use the the name "f_fwd", - but prefix it with the "opaque_defs." identifier. - *) + trait_decl_name : trait_decl -> string; + trait_impl_name : trait_decl -> trait_impl -> string; + trait_decl_constructor : trait_decl -> string; + trait_parent_clause_name : trait_decl -> trait_clause -> string; + trait_const_name : trait_decl -> string -> string; + trait_type_name : trait_decl -> string -> string; + trait_method_name : trait_decl -> string -> string; + trait_type_clause_name : trait_decl -> string -> trait_clause -> string; var_basename : StringSet.t -> string option -> ty -> string; (** Generates a variable basename. @@ -288,6 +270,14 @@ type formatter = { (** Generates a type variable basename. *) const_generic_var_basename : StringSet.t -> string -> string; (** Generates a const generic variable basename. *) + trait_self_clause_basename : string; + trait_clause_basename : StringSet.t -> trait_clause -> string; + (** Return a base name for a trait clause. We might add a suffix to prevent + collisions. + + In the traduction we explicitely manipulate the trait clause instances, + that is we introduce one input variable for each trait clause. + *) append_index : string -> int -> string; (** Appends an index to a name - we use this to generate unique names: when doing so, the role of the formatter is just to concatenate @@ -396,10 +386,60 @@ type id = | TypeVarId of TypeVarId.id | ConstGenericVarId of ConstGenericVarId.id | VarId of VarId.id + | TraitDeclId of TraitDeclId.id + | TraitImplId of TraitImplId.id + | LocalTraitClauseId of TraitClauseId.id + | TraitDeclConstructorId of TraitDeclId.id + | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option + (** Something peculiar with trait methods: because we have to take into + account forward/backward functions, we may need to generate fields + items per method. + *) + | TraitItemId of TraitDeclId.id * string + (** A trait associated item which is not a method *) + | TraitParentClauseId of TraitDeclId.id * TraitClauseId.id + | TraitItemClauseId of TraitDeclId.id * string * TraitClauseId.id + | TraitSelfClauseId + (** Specifically for the clause: [Self : Trait]. + + For now, we forbid provided methods (methods in trait declarations + with a default implementation) from being overriden in trait implementations. + We extract trait provided methods such that they take an instance of + the trait as input: this instance is given by the trait self clause. + + For instance: + {[ + // + // Rust + // + trait ToU64 { + fn to_u64(&self) -> u64; + + // Provided method + fn is_pos(&self) -> bool { + self.to_u64() > 0 + } + } + + // + // Generated code + // + struct ToU64 (T : Type) { + to_u64 : T -> u64; + } + + // The trait self clause + // vvvvvvvvvvvvvvvvvvvvvv + let is_pos (T : Type) (trait_self : ToU64 T) (self : T) : bool = + trait_self.to_u64 self > 0 + ]} + *) | UnknownId (** Used for stored various strings like keywords, definitions which should always be in context, etc. and which can't be linked to one of the above. + + TODO: rename to "keyword" *) [@@deriving show, ord] @@ -429,69 +469,64 @@ type names_map = { precisely which identifiers are mapped to the same name... *) names_set : StringSet.t; - opaque_ids : IdSet.t; - (** TODO: this is obsolete. Remove. +} - The set of opaque definitions. +let empty_names_map : names_map = + { + id_to_name = IdMap.empty; + name_to_id = StringMap.empty; + names_set = StringSet.empty; + } - See {!formatter.opaque_pre} for detailed explanations about why - we need to know which definitions are opaque to compute names. +(** Small helper to report name collision *) +let report_name_collision (id_to_string : id -> string) (id1 : id) (id2 : id) + (name : string) : unit = + let id1 = "\n- " ^ id_to_string id1 in + let id2 = "\n- " ^ id_to_string id2 in + let err = + "Name clash detected: the following identifiers are bound to the same name \ + \"" ^ name ^ "\":" ^ id1 ^ id2 + ^ "\nYou may want to rename some of your definitions, or report an issue." + in + log#serror err; + (* If we fail hard on errors, raise an exception *) + if !Config.fail_hard then raise (Failure err) - Also note that the opaque ids don't contain the ids of the assumed - definitions. In practice, assumed definitions are opaque_defs. However, they - are not grouped in the opaque module, meaning we never need to - prefix them (with, say, "opaque_defs."): we thus consider them as non-opaque - with regards to the names map. - *) -} +let names_map_get_id_from_name (name : string) (nm : names_map) : id option = + StringMap.find_opt name nm.name_to_id -let names_map_add (id_to_string : id -> string) (is_opaque : bool) (id : id) - (name : string) (nm : names_map) : names_map = - (* Check if there is a clash *) - (match StringMap.find_opt name nm.name_to_id with +let names_map_check_collision (id_to_string : id -> string) (id : id) + (name : string) (nm : names_map) : unit = + match names_map_get_id_from_name name nm with | None -> () (* Ok *) | Some clash -> (* There is a clash: print a nice debugging message for the user *) - let id1 = "\n- " ^ id_to_string clash in - let id2 = "\n- " ^ id_to_string id in - let err = - "Name clash detected: the following identifiers are bound to the same \ - name \"" ^ name ^ "\":" ^ id1 ^ id2 - in - log#serror err; - raise (Failure err)); - (* Sanity check *) - assert (not (StringSet.mem name nm.names_set)); + report_name_collision id_to_string clash id name + +(** Insert bindings in a names map without checking for collisions *) +let names_map_add_unchecked (id : id) (name : string) (nm : names_map) : + names_map = (* Insert *) let id_to_name = IdMap.add id name nm.id_to_name in let name_to_id = StringMap.add name id nm.name_to_id in let names_set = StringSet.add name nm.names_set in - let opaque_ids = - if is_opaque then IdSet.add id nm.opaque_ids else nm.opaque_ids - in - { id_to_name; name_to_id; names_set; opaque_ids } - -let names_map_add_assumed_type (id_to_string : id -> string) (id : assumed_ty) - (name : string) (nm : names_map) : names_map = - let is_opaque = false in - names_map_add id_to_string is_opaque (TypeId (Assumed id)) name nm - -let names_map_add_assumed_struct (id_to_string : id -> string) (id : assumed_ty) - (name : string) (nm : names_map) : names_map = - let is_opaque = false in - names_map_add id_to_string is_opaque (StructId (Assumed id)) name nm + { id_to_name; name_to_id; names_set } -let names_map_add_assumed_variant (id_to_string : id -> string) - (id : assumed_ty) (variant_id : VariantId.id) (name : string) +let names_map_add (id_to_string : id -> string) (id : id) (name : string) (nm : names_map) : names_map = - let is_opaque = false in - names_map_add id_to_string is_opaque - (VariantId (Assumed id, variant_id)) - name nm - -let names_map_add_function (id_to_string : id -> string) (is_opaque : bool) - (fid : fun_id) (name : string) (nm : names_map) : names_map = - names_map_add id_to_string is_opaque (FunId fid) name nm + (* Check if there is a clash *) + names_map_check_collision id_to_string id name nm; + (* Sanity check *) + if StringSet.mem name nm.names_set then ( + let err = + "Error when registering the name for id: " ^ id_to_string id + ^ ":\nThe chosen name is already in the names set: " ^ name + in + log#serror err; + (* If we fail hard on errors, raise an exception *) + if !Config.fail_hard then raise (Failure err)); + (* Insert *) + names_map_add_unchecked id name nm (** The unsafe names map stores mappings from identifiers to names which might collide. For some backends and some names, it might be acceptable to have @@ -503,6 +538,8 @@ let names_map_add_function (id_to_string : id -> string) (is_opaque : bool) *) type unsafe_names_map = { id_to_name : string IdMap.t } +let empty_unsafe_names_map = { id_to_name = IdMap.empty } + let unsafe_names_map_add (id : id) (name : string) (nm : unsafe_names_map) : unsafe_names_map = { id_to_name = IdMap.add id name nm.id_to_name } @@ -541,6 +578,24 @@ let basename_to_unique (names_set : StringSet.t) type fun_name_info = { keep_fwd : bool; num_backs : int } +type names_maps = { + names_map : names_map; + (** The map for id to names, where we forbid name collisions + (ex.: we always forbid function name collisions). *) + unsafe_names_map : unsafe_names_map; + (** The map for id to names, where we allow name collisions + (ex.: we might allow record field name collisions). *) + strict_names_map : names_map; + (** This map is a sub-map of [names_map]. For the ids in this map we also + forbid collisions with names in the [unsafe_names_map]. + + We do so for keywords for instance, but also for types (in a dependently + typed language, we might have an issue if the field of a record has, say, + the name "u32", and another field of the same record refers to "u32" + (for instance in its type). + *) +} + (** Extraction context. Note that the extraction context contains information coming from the @@ -549,24 +604,12 @@ type fun_name_info = { keep_fwd : bool; num_backs : int } functions, etc. *) type extraction_ctx = { + crate : A.crate; trans_ctx : trans_ctx; - names_map : names_map; - (** The map for id to names, where we forbid name collisions - (ex.: we always forbid function name collisions). *) - unsafe_names_map : unsafe_names_map; - (** The map for id to names, where we allow name collisions - (ex.: we might allow record field name collisions). *) + names_maps : names_maps; fmt : formatter; indent_incr : int; (** The indent increment we insert whenever we need to indent more *) - use_opaque_pre : bool; - (** Do we use the "opaque_defs." prefix for the opaque definitions? - - Opaque function definitions might refer opaque types: if we are in the - opaque module, we musn't use the "opaque_defs." prefix, otherwise we - use it. - Also see {!names_map.opaque_ids}. - *) use_dep_ite : bool; (** For Lean: do we use dependent-if then else expressions? @@ -586,6 +629,29 @@ type extraction_ctx = { in case a Rust function only has one backward translation and we filter the forward function because it returns unit. *) + trait_decl_id : trait_decl_id option; + (** If we are extracting a trait declaration, identifies it *) + is_provided_method : bool; + trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; + trans_funs : pure_fun_translation A.FunDeclId.Map.t; + functions_with_decreases_clause : PureUtils.FunLoopIdSet.t; + trans_trait_decls : Pure.trait_decl Pure.TraitDeclId.Map.t; + trans_trait_impls : Pure.trait_impl Pure.TraitImplId.Map.t; + types_filter_type_args_map : bool list TypeDeclId.Map.t; + (** The map to filter the type arguments for the builtin type + definitions. + + We need this for type `Vec`, for instance, which takes a useless + (in the context of the type translation) type argument for the + allocator which is used, and which we want to remove. + + TODO: it would be cleaner to filter those types in a micro-pass, + rather than at code generation time. + *) + funs_filter_type_args_map : bool list FunDeclId.Map.t; + (** Same as {!types_filter_type_args_map}, but for functions *) + trait_impls_filter_type_args_map : bool list TraitImplId.Map.t; + (** Same as {!types_filter_type_args_map}, but for trait implementations *) } (** Debugging function, used when communicating name collisions to the user, @@ -593,9 +659,16 @@ type extraction_ctx = { instance). *) let id_to_string (id : id) (ctx : extraction_ctx) : string = - let global_decls = ctx.trans_ctx.global_context.global_decls in - let fun_decls = ctx.trans_ctx.fun_context.fun_decls in - let type_decls = ctx.trans_ctx.type_context.type_decls in + let global_decls = ctx.trans_ctx.global_ctx.global_decls in + let fun_decls = ctx.trans_ctx.fun_ctx.fun_decls in + let type_decls = ctx.trans_ctx.type_ctx.type_decls in + let trait_decls = ctx.trans_ctx.trait_decls_ctx.trait_decls in + let trait_decl_id_to_string (id : A.TraitDeclId.id) : string = + let trait_name = + Print.fun_name_to_string (A.TraitDeclId.Map.find id trait_decls).name + in + "trait_decl: " ^ trait_name ^ " (id: " ^ A.TraitDeclId.to_string id ^ ")" + in (* TODO: factorize the pretty-printing with what is in PrintPure *) let get_type_name (id : type_id) : string = match id with @@ -614,10 +687,17 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | FromLlbc (fid, lp_id, rg_id) -> let fun_name = match fid with - | Regular fid -> + | FunId (Regular fid) -> Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name - | Assumed aid -> A.show_assumed_fun_id aid + | FunId (Assumed aid) -> A.show_assumed_fun_id aid + | TraitMethod (trait_ref, method_name, _) -> + (* Shouldn't happen *) + if !Config.fail_hard then raise (Failure "Unexpected") + else + "Trait method: decl: " + ^ TraitDeclId.to_string trait_ref.trait_decl_ref.trait_decl_id + ^ ", method_name: " ^ method_name in let lp_kind = @@ -673,12 +753,16 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = if variant_id = error_failure_id then "@error::Failure" else if variant_id = error_out_of_fuel_id then "@error::OutOfFuel" else raise (Failure "Unreachable") - | Assumed Option -> - if variant_id = option_some_id then "@option::Some" - else if variant_id = option_none_id then "@option::None" + | Assumed Fuel -> + if variant_id = fuel_zero_id then "@fuel::0" + else if variant_id = fuel_succ_id then "@fuel::Succ" else raise (Failure "Unreachable") - | Assumed (State | Vec | Fuel | Array | Slice | Str | Range) -> - raise (Failure "Unreachable") + | Assumed (State | Array | Slice | Str | RawPtr _) -> + raise + (Failure + ("Unreachable: variant id (" + ^ VariantId.to_string variant_id + ^ ") for " ^ show_type_id id)) | AdtId id -> ( let def = TypeDeclId.Map.find id type_decls in match def.kind with @@ -693,8 +777,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = match id with | Tuple -> raise (Failure "Unreachable") | Assumed - ( State | Result | Error | Fuel | Option | Vec | Array | Slice | Str - | Range ) -> + (State | Result | Error | Fuel | Array | Slice | Str | RawPtr _) -> (* We can't directly have access to the fields of those types *) raise (Failure "Unreachable") | AdtId id -> ( @@ -716,134 +799,265 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | ConstGenericVarId id -> "const_generic_var_id: " ^ ConstGenericVarId.to_string id | VarId id -> "var_id: " ^ VarId.to_string id + | TraitDeclId id -> "trait_decl_id: " ^ TraitDeclId.to_string id + | TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id + | LocalTraitClauseId id -> + "local_trait_clause_id: " ^ TraitClauseId.to_string id + | TraitDeclConstructorId id -> + "trait_decl_constructor: " ^ trait_decl_id_to_string id + | TraitParentClauseId (id, clause_id) -> + "trait_parent_clause_id: " ^ trait_decl_id_to_string id ^ ", clause_id: " + ^ TraitClauseId.to_string clause_id + | TraitItemClauseId (id, item_name, clause_id) -> + "trait_item_clause_id: " ^ trait_decl_id_to_string id ^ ", item name: " + ^ item_name ^ ", clause_id: " + ^ TraitClauseId.to_string clause_id + | TraitItemId (id, name) -> + "trait_item_id: " ^ trait_decl_id_to_string id ^ ", type name: " ^ name + | TraitMethodId (trait_decl_id, fun_name, rg_id) -> + let fwd_back_kind = + match rg_id with + | None -> "forward" + | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id + in + trait_decl_id_to_string trait_decl_id + ^ ", method name (" ^ fwd_back_kind ^ "): " ^ fun_name + | TraitSelfClauseId -> "trait_self_clause" + +(** Return [true] if we are strict on collisions for this id (i.e., we forbid + collisions even with the ids in the unsafe names map) *) +let strict_collisions (id : id) : bool = + match id with UnknownId | TypeId _ -> true | _ -> false (** We might not check for collisions for some specific ids (ex.: field names) *) let allow_collisions (id : id) : bool = match id with - | FieldId (_, _) -> !Config.record_fields_short_names + | FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitItemId _ + | TraitMethodId _ -> + !Config.record_fields_short_names + | FunId (Pure _ | FromLlbc (FunId (Assumed _), _, _)) -> + (* We map several assumed functions to the same id *) + true | _ -> false -let ctx_add (is_opaque : bool) (id : id) (name : string) (ctx : extraction_ctx) - : extraction_ctx = - (* We do not use the same name map if we allow/disallow collisions *) +(** The [id_to_string] function to print nice debugging messages if there are + collisions *) +let names_maps_add (id_to_string : id -> string) (id : id) (name : string) + (nm : names_maps) : names_maps = + (* We do not use the same name map if we allow/disallow collisions. + We notably use it for field names: some backends like Lean can use the + type information to disambiguate field projections. + + Remark: we still need to check that those "unsafe" ids don't collide with + the ids that we mark as "strict on collision". + + For instance, we don't allow naming a field "let". We enforce this by + not checking collision between ids for which we permit collisions (ex.: + between fields), but still checking collisions between those ids and the + others (ex.: fields and keywords). + *) if allow_collisions id then ( - assert (not is_opaque); + (* Check with the ids which are considered to be strict on collisions *) + names_map_check_collision id_to_string id name nm.strict_names_map; { - ctx with - unsafe_names_map = unsafe_names_map_add id name ctx.unsafe_names_map; + nm with + unsafe_names_map = unsafe_names_map_add id name nm.unsafe_names_map; }) else - (* The id_to_string function to print nice debugging messages if there are - * collisions *) - let id_to_string (id : id) : string = id_to_string id ctx in - let names_map = - names_map_add id_to_string is_opaque id name ctx.names_map + (* Remark: if we are strict on collisions: + - we add the id to the strict collisions map + - we check that the id doesn't collide with the unsafe map + TODO: we might not check that: + - a user defined function doesn't collide with an assumed function + - two trait decl items don't collide with each other + *) + let strict_names_map = + if strict_collisions id then + names_map_add id_to_string id name nm.strict_names_map + else nm.strict_names_map in - { ctx with names_map } + let names_map = names_map_add id_to_string id name nm.names_map in + { nm with strict_names_map; names_map } + +let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = + let id_to_string (id : id) : string = id_to_string id ctx in + let names_maps = names_maps_add id_to_string id name ctx.names_maps in + { ctx with names_maps } -(** [with_opaque_pre]: if [true] and the definition is opaque, add the opaque prefix *) -let ctx_get (with_opaque_pre : bool) (id : id) (ctx : extraction_ctx) : string = +(** The [id_to_string] function to print nice debugging messages if there are + collisions *) +let names_maps_get (id_to_string : id -> string) (id : id) (nm : names_maps) : + string = (* We do not use the same name map if we allow/disallow collisions *) - if allow_collisions id then IdMap.find id ctx.unsafe_names_map.id_to_name + let map_to_string (m : string IdMap.t) : string = + "[\n" + ^ String.concat "," + (List.map + (fun (id, n) -> "\n " ^ id_to_string id ^ " -> " ^ n) + (IdMap.bindings m)) + ^ "\n]" + in + if allow_collisions id then ( + let m = nm.unsafe_names_map.id_to_name in + match IdMap.find_opt id m with + | Some s -> s + | None -> + let err = + "Could not find: " ^ id_to_string id ^ "\nNames map:\n" + ^ map_to_string m + in + log#serror err; + if !Config.fail_hard then raise (Failure err) + else "(%%%ERROR: unknown identifier\": " ^ id_to_string id ^ "\"%%%)") else - match IdMap.find_opt id ctx.names_map.id_to_name with - | Some s -> - let is_opaque = IdSet.mem id ctx.names_map.opaque_ids in - if with_opaque_pre && is_opaque then ctx.fmt.opaque_pre () ^ s else s + let m = nm.names_map.id_to_name in + match IdMap.find_opt id m with + | Some s -> s | None -> - log#serror ("Could not find: " ^ id_to_string id ctx); - raise Not_found + let err = + "Could not find: " ^ id_to_string id ^ "\nNames map:\n" + ^ map_to_string m + in + log#serror err; + if !Config.fail_hard then raise (Failure err) + else "(ERROR: \"" ^ id_to_string id ^ "\")" + +let ctx_get (id : id) (ctx : extraction_ctx) : string = + let id_to_string (id : id) : string = id_to_string id ctx in + names_maps_get id_to_string id ctx.names_maps + +let names_maps_add_assumed_type (id_to_string : id -> string) (id : assumed_ty) + (name : string) (nm : names_maps) : names_maps = + names_maps_add id_to_string (TypeId (Assumed id)) name nm + +let names_maps_add_assumed_struct (id_to_string : id -> string) + (id : assumed_ty) (name : string) (nm : names_maps) : names_maps = + names_maps_add id_to_string (StructId (Assumed id)) name nm -let ctx_get_global (with_opaque_pre : bool) (id : A.GlobalDeclId.id) +let names_maps_add_assumed_variant (id_to_string : id -> string) + (id : assumed_ty) (variant_id : VariantId.id) (name : string) + (nm : names_maps) : names_maps = + names_maps_add id_to_string (VariantId (Assumed id, variant_id)) name nm + +let names_maps_add_function (id_to_string : id -> string) (fid : fun_id) + (name : string) (nm : names_maps) : names_maps = + names_maps_add id_to_string (FunId fid) name nm + +let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string = + ctx_get (GlobalId id) ctx + +let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string = + ctx_get (FunId id) ctx + +let ctx_get_local_function (id : A.FunDeclId.id) (lp : LoopId.id option) + (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = + ctx_get_function (FromLlbc (FunId (Regular id), lp, rg)) ctx + +let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = + assert (id <> Tuple); + ctx_get (TypeId id) ctx + +let ctx_get_local_type (id : TypeDeclId.id) (ctx : extraction_ctx) : string = + ctx_get_type (AdtId id) ctx + +let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string = + ctx_get_type (Assumed id) ctx + +let ctx_get_trait_constructor (id : trait_decl_id) (ctx : extraction_ctx) : + string = + ctx_get (TraitDeclConstructorId id) ctx + +let ctx_get_trait_self_clause (ctx : extraction_ctx) : string = + ctx_get TraitSelfClauseId ctx + +let ctx_get_trait_decl (id : trait_decl_id) (ctx : extraction_ctx) : string = + ctx_get (TraitDeclId id) ctx + +let ctx_get_trait_impl (id : trait_impl_id) (ctx : extraction_ctx) : string = + ctx_get (TraitImplId id) ctx + +let ctx_get_trait_item (id : trait_decl_id) (item_name : string) (ctx : extraction_ctx) : string = - ctx_get with_opaque_pre (GlobalId id) ctx + ctx_get (TraitItemId (id, item_name)) ctx -let ctx_get_function (with_opaque_pre : bool) (id : fun_id) +let ctx_get_trait_const (id : trait_decl_id) (item_name : string) (ctx : extraction_ctx) : string = - ctx_get with_opaque_pre (FunId id) ctx + ctx_get_trait_item id item_name ctx -let ctx_get_local_function (with_opaque_pre : bool) (id : A.FunDeclId.id) - (lp : LoopId.id option) (rg : RegionGroupId.id option) +let ctx_get_trait_type (id : trait_decl_id) (item_name : string) (ctx : extraction_ctx) : string = - ctx_get_function with_opaque_pre (FromLlbc (Regular id, lp, rg)) ctx + ctx_get_trait_item id item_name ctx -let ctx_get_type (with_opaque_pre : bool) (id : type_id) (ctx : extraction_ctx) - : string = - assert (id <> Tuple); - ctx_get with_opaque_pre (TypeId id) ctx +let ctx_get_trait_method (id : trait_decl_id) (item_name : string) + (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string = + ctx_get (TraitMethodId (id, item_name, rg_id)) ctx -let ctx_get_local_type (with_opaque_pre : bool) (id : TypeDeclId.id) +let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id) (ctx : extraction_ctx) : string = - ctx_get_type with_opaque_pre (AdtId id) ctx + ctx_get (TraitParentClauseId (id, clause)) ctx -let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string = - (* In practice, the assumed types are opaque. However, assumed types - are never grouped in the opaque module, meaning we never need to - prefix them: we thus consider them as non-opaque with regards to the - names map. - *) - let is_opaque = false in - ctx_get_type is_opaque (Assumed id) ctx +let ctx_get_trait_item_clause (id : trait_decl_id) (item : string) + (clause : trait_clause_id) (ctx : extraction_ctx) : string = + ctx_get (TraitItemClauseId (id, item, clause)) ctx let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string = - let is_opaque = false in - ctx_get is_opaque (VarId id) ctx + ctx_get (VarId id) ctx let ctx_get_type_var (id : TypeVarId.id) (ctx : extraction_ctx) : string = - let is_opaque = false in - ctx_get is_opaque (TypeVarId id) ctx + ctx_get (TypeVarId id) ctx let ctx_get_const_generic_var (id : ConstGenericVarId.id) (ctx : extraction_ctx) : string = - let is_opaque = false in - ctx_get is_opaque (ConstGenericVarId id) ctx + ctx_get (ConstGenericVarId id) ctx + +let ctx_get_local_trait_clause (id : TraitClauseId.id) (ctx : extraction_ctx) : + string = + ctx_get (LocalTraitClauseId id) ctx let ctx_get_field (type_id : type_id) (field_id : FieldId.id) (ctx : extraction_ctx) : string = - let is_opaque = false in - ctx_get is_opaque (FieldId (type_id, field_id)) ctx + ctx_get (FieldId (type_id, field_id)) ctx -let ctx_get_struct (with_opaque_pre : bool) (def_id : type_id) - (ctx : extraction_ctx) : string = - ctx_get with_opaque_pre (StructId def_id) ctx +let ctx_get_struct (def_id : type_id) (ctx : extraction_ctx) : string = + ctx_get (StructId def_id) ctx let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id) (ctx : extraction_ctx) : string = - let is_opaque = false in - ctx_get is_opaque (VariantId (def_id, variant_id)) ctx + ctx_get (VariantId (def_id, variant_id)) ctx let ctx_get_decreases_proof (def_id : A.FunDeclId.id) (loop_id : LoopId.id option) (ctx : extraction_ctx) : string = - let is_opaque = false in - ctx_get is_opaque (DecreasesProofId (Regular def_id, loop_id)) ctx + ctx_get (DecreasesProofId (Regular def_id, loop_id)) ctx let ctx_get_termination_measure (def_id : A.FunDeclId.id) (loop_id : LoopId.id option) (ctx : extraction_ctx) : string = - let is_opaque = false in - ctx_get is_opaque (TerminationMeasureId (Regular def_id, loop_id)) ctx + ctx_get (TerminationMeasureId (Regular def_id, loop_id)) ctx (** Generate a unique type variable name and add it to the context *) let ctx_add_type_var (basename : string) (id : TypeVarId.id) (ctx : extraction_ctx) : extraction_ctx * string = - let is_opaque = false in - let name = ctx.fmt.type_var_basename ctx.names_map.names_set basename in let name = - basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name + ctx.fmt.type_var_basename ctx.names_maps.names_map.names_set basename + in + let name = + basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index + name in - let ctx = ctx_add is_opaque (TypeVarId id) name ctx in + let ctx = ctx_add (TypeVarId id) name ctx in (ctx, name) (** Generate a unique const generic variable name and add it to the context *) let ctx_add_const_generic_var (basename : string) (id : ConstGenericVarId.id) (ctx : extraction_ctx) : extraction_ctx * string = - let is_opaque = false in let name = - ctx.fmt.const_generic_var_basename ctx.names_map.names_set basename + ctx.fmt.const_generic_var_basename ctx.names_maps.names_map.names_set + basename in let name = - basename_to_unique ctx.names_map.names_set ctx.fmt.append_index name + basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index + name in - let ctx = ctx_add is_opaque (ConstGenericVarId id) name ctx in + let ctx = ctx_add (ConstGenericVarId id) name ctx in (ctx, name) (** See {!ctx_add_type_var} *) @@ -856,11 +1070,31 @@ let ctx_add_type_vars (vars : (string * TypeVarId.id) list) (** Generate a unique variable name and add it to the context *) let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) : extraction_ctx * string = - let is_opaque = false in let name = - basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename + basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index + basename in - let ctx = ctx_add is_opaque (VarId id) name ctx in + let ctx = ctx_add (VarId id) name ctx in + (ctx, name) + +(** Generate a unique variable name for the trait self clause and add it to the context *) +let ctx_add_trait_self_clause (ctx : extraction_ctx) : extraction_ctx * string = + let basename = ctx.fmt.trait_self_clause_basename in + let name = + basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index + basename + in + let ctx = ctx_add TraitSelfClauseId name ctx in + (ctx, name) + +(** Generate a unique trait clause name and add it to the context *) +let ctx_add_local_trait_clause (basename : string) (id : TraitClauseId.id) + (ctx : extraction_ctx) : extraction_ctx * string = + let name = + basename_to_unique ctx.names_maps.names_map.names_set ctx.fmt.append_index + basename + in + let ctx = ctx_add (LocalTraitClauseId id) name ctx in (ctx, name) (** See {!ctx_add_var} *) @@ -868,7 +1102,9 @@ let ctx_add_vars (vars : var list) (ctx : extraction_ctx) : extraction_ctx * string list = List.fold_left_map (fun ctx (v : var) -> - let name = ctx.fmt.var_basename ctx.names_map.names_set v.basename v.ty in + let name = + ctx.fmt.var_basename ctx.names_maps.names_map.names_set v.basename v.ty + in ctx_add_var name v.id ctx) ctx vars @@ -885,142 +1121,105 @@ let ctx_add_const_generic_params (vars : const_generic_var list) ctx_add_const_generic_var var.name var.index ctx) ctx vars -let ctx_add_type_const_generic_params (tvars : type_var list) - (cgvars : const_generic_var list) (ctx : extraction_ctx) : - extraction_ctx * string list * string list = - let ctx, tys = ctx_add_type_params tvars ctx in - let ctx, cgs = ctx_add_const_generic_params cgvars ctx in - (ctx, tys, cgs) - -let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) : - extraction_ctx * string = - assert (match def.kind with Struct _ -> true | _ -> false); - let is_opaque = false in - let cons_name = ctx.fmt.struct_constructor def.name in - let ctx = ctx_add is_opaque (StructId (AdtId def.def_id)) cons_name ctx in - (ctx, cons_name) - -let ctx_add_type_decl (def : type_decl) (ctx : extraction_ctx) : extraction_ctx - = - let is_opaque = def.kind = Opaque in - let def_name = ctx.fmt.type_name def.name in - let ctx = ctx_add is_opaque (TypeId (AdtId def.def_id)) def_name ctx in - ctx - -let ctx_add_field (def : type_decl) (field_id : FieldId.id) (field : field) - (ctx : extraction_ctx) : extraction_ctx * string = - let is_opaque = false in - let name = ctx.fmt.field_name def.name field_id field.field_name in - let ctx = ctx_add is_opaque (FieldId (AdtId def.def_id, field_id)) name ctx in - (ctx, name) - -let ctx_add_fields (def : type_decl) (fields : (FieldId.id * field) list) +let ctx_add_local_trait_clauses (clauses : trait_clause list) (ctx : extraction_ctx) : extraction_ctx * string list = List.fold_left_map - (fun ctx (vid, v) -> ctx_add_field def vid v ctx) - ctx fields - -let ctx_add_variant (def : type_decl) (variant_id : VariantId.id) - (variant : variant) (ctx : extraction_ctx) : extraction_ctx * string = - let is_opaque = false in - let name = ctx.fmt.variant_name def.name variant.variant_name in - (* Add the type name prefix for Lean *) - let name = - if !Config.backend = Lean then - let type_name = ctx.fmt.type_name def.name in - type_name ^ "." ^ name - else name - in - let ctx = - ctx_add is_opaque (VariantId (AdtId def.def_id, variant_id)) name ctx - in - (ctx, name) - -let ctx_add_variants (def : type_decl) - (variants : (VariantId.id * variant) list) (ctx : extraction_ctx) : - extraction_ctx * string list = - List.fold_left_map - (fun ctx (vid, v) -> ctx_add_variant def vid v ctx) - ctx variants + (fun ctx (c : trait_clause) -> + let basename = + ctx.fmt.trait_clause_basename ctx.names_maps.names_map.names_set c + in + ctx_add_local_trait_clause basename c.clause_id ctx) + ctx clauses -let ctx_add_struct (def : type_decl) (ctx : extraction_ctx) : - extraction_ctx * string = - assert (match def.kind with Struct _ -> true | _ -> false); - let is_opaque = false in - let name = ctx.fmt.struct_constructor def.name in - let ctx = ctx_add is_opaque (StructId (AdtId def.def_id)) name ctx in - (ctx, name) +(** Returns the lists of names for: + - the type variables + - the const generic variables + - the trait clauses + *) +let ctx_add_generic_params (generics : generic_params) (ctx : extraction_ctx) : + extraction_ctx * string list * string list * string list = + let { types; const_generics; trait_clauses } = generics in + let ctx, tys = ctx_add_type_params types ctx in + let ctx, cgs = ctx_add_const_generic_params const_generics ctx in + let ctx, tcs = ctx_add_local_trait_clauses trait_clauses ctx in + (ctx, tys, cgs, tcs) let ctx_add_decreases_proof (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = - let is_opaque = false in let name = ctx.fmt.decreases_proof_name def.def_id def.basename def.num_loops def.loop_id in - ctx_add is_opaque - (DecreasesProofId (Regular def.def_id, def.loop_id)) - name ctx + ctx_add (DecreasesProofId (Regular def.def_id, def.loop_id)) name ctx let ctx_add_termination_measure (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = - let is_opaque = false in let name = ctx.fmt.termination_measure_name def.def_id def.basename def.num_loops def.loop_id in - ctx_add is_opaque - (TerminationMeasureId (Regular def.def_id, def.loop_id)) - name ctx + ctx_add (TerminationMeasureId (Regular def.def_id, def.loop_id)) name ctx let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : extraction_ctx = (* TODO: update once the body id can be an option *) - let is_opaque = false in - let name = ctx.fmt.global_name def.name in let decl = GlobalId def.def_id in - let body = FunId (FromLlbc (Regular def.body_id, None, None)) in - let ctx = ctx_add is_opaque decl (name ^ "_c") ctx in - let ctx = ctx_add is_opaque body (name ^ "_body") ctx in - ctx -let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) - (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = - (* Sanity check: the function should not be a global body - those are handled - * separately *) - assert (not def.is_global_decl_body); + (* Check if the global corresponds to an assumed global that we should map + to a custom definition in our standard library (for instance, happens + with "core::num::usize::MAX") *) + let sname = name_to_simple_name def.name in + match SimpleNameMap.find_opt sname builtin_globals_map with + | Some name -> + (* Yes: register the custom binding *) + ctx_add decl name ctx + | None -> + (* Not the case: "standard" registration *) + let name = ctx.fmt.global_name def.name in + let body = FunId (FromLlbc (FunId (Regular def.body_id), None, None)) in + let ctx = ctx_add decl (name ^ "_c") ctx in + let ctx = ctx_add body (name ^ "_body") ctx in + ctx + +let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl) + (ctx : extraction_ctx) : string = (* Lookup the LLBC def to compute the region group information *) let def_id = def.def_id in - let llbc_def = - A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls - in + let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in let sg = llbc_def.signature in let num_rgs = List.length sg.regions_hierarchy in - let keep_fwd, (_, backs) = trans_group in + let { keep_fwd; fwd = _; backs } = trans_group in let num_backs = List.length backs in let rg_info = match def.back_id with | None -> None | Some rg_id -> let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in - let regions = + let region_names = List.map - (fun rid -> T.RegionVarId.nth sg.region_params rid) + (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) rg.regions in - let region_names = - List.map (fun (r : T.region_var) -> r.name) regions - in Some { id = rg_id; region_names } in - let is_opaque = def.body = None in (* Add the function name *) - let def_name = - ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info - (keep_fwd, num_backs) - in - let fun_id = (A.Regular def_id, def.loop_id, def.back_id) in - let ctx = ctx_add is_opaque (FunId (FromLlbc fun_id)) def_name ctx in + ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info + (keep_fwd, num_backs) + +(* TODO: move to Extract *) +let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl) + (ctx : extraction_ctx) : extraction_ctx = + (* Sanity check: the function should not be a global body - those are handled + * separately *) + assert (not def.is_global_decl_body); + (* Lookup the LLBC def to compute the region group information *) + let def_id = def.def_id in + let { keep_fwd; fwd = _; backs } = trans_group in + let num_backs = List.length backs in + (* Add the function name *) + let def_name = ctx_compute_fun_name trans_group def ctx in + let fun_id = (Pure.FunId (Regular def_id), def.loop_id, def.back_id) in + let ctx = ctx_add (FunId (FromLlbc fun_id)) def_name ctx in (* Add the name info *) { ctx with @@ -1039,9 +1238,10 @@ type names_map_init = { assumed_pure_functions : (pure_assumed_fun_id * string) list; } -(** Initialize a names map with a proper set of keywords/names coming from the +(** Initialize names maps with a proper set of keywords/names coming from the target language/prover. *) -let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = +let initialize_names_maps (fmt : formatter) (init : names_map_init) : names_maps + = let int_names = List.map fmt.int_name T.all_int_types in let keywords = List.concat @@ -1049,20 +1249,30 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = [ fmt.bool_name; fmt.char_name; fmt.str_name ]; int_names; init.keywords; ] in - let names_set = StringSet.of_list keywords in - let name_to_id = - StringMap.of_list (List.map (fun x -> (x, UnknownId)) keywords) - in - let opaque_ids = IdSet.empty in + let names_set = StringSet.empty in + let name_to_id = StringMap.empty in (* We fist initialize [id_to_name] as empty, because the id of a keyword is [UnknownId]. * Also note that we don't need this mapping for keywords: we insert keywords only * to check collisions. *) let id_to_name = IdMap.empty in - let nm = { id_to_name; name_to_id; names_set; opaque_ids } in + let names_map = { id_to_name; name_to_id; names_set } in + let unsafe_names_map = empty_unsafe_names_map in + let strict_names_map = empty_names_map in (* For debugging - we are creating bindings for assumed types and functions, so * it is ok if we simply use the "show" function (those aren't simply identified * by numbers) *) let id_to_string = show_id in + (* Add the keywords as strict collisions *) + let strict_names_map = + List.fold_left + (fun nm name -> + (* There is duplication in the keywords so we don't check the collisions + while registering them (what is important is that there are no collisions + between keywords and user-defined identifiers) *) + names_map_add_unchecked UnknownId name nm) + strict_names_map keywords + in + let nm = { names_map; unsafe_names_map; strict_names_map } in (* Then we add: * - the assumed types * - the assumed struct constructors @@ -1072,37 +1282,31 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = let nm = List.fold_left (fun nm (type_id, name) -> - names_map_add_assumed_type id_to_string type_id name nm) + names_maps_add_assumed_type id_to_string type_id name nm) nm init.assumed_adts in let nm = List.fold_left (fun nm (type_id, name) -> - names_map_add_assumed_struct id_to_string type_id name nm) + names_maps_add_assumed_struct id_to_string type_id name nm) nm init.assumed_structs in let nm = List.fold_left (fun nm (type_id, variant_id, name) -> - names_map_add_assumed_variant id_to_string type_id variant_id name nm) + names_maps_add_assumed_variant id_to_string type_id variant_id name nm) nm init.assumed_variants in let assumed_functions = List.map - (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, None, rg), name)) + (fun (fid, rg, name) -> + (FromLlbc (Pure.FunId (Assumed fid), None, rg), name)) init.assumed_llbc_functions @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions in let nm = - (* In practice, the assumed function are opaque. However, assumed functions - are never grouped in the opaque module, meaning we never need to - prefix them: we thus consider them as non-opaque with regards to the - names map. - *) - let is_opaque = false in List.fold_left - (fun nm (fid, name) -> - names_map_add_function id_to_string is_opaque fid name nm) + (fun nm (fid, name) -> names_maps_add_function id_to_string fid name nm) nm assumed_functions in (* Return *) @@ -1150,22 +1354,20 @@ let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) let rg_suff = (* TODO: make all the backends match what is done for Lean *) match rg with - | None -> ( - match !Config.backend with - | FStar | Coq | HOL4 -> "_fwd" - | Lean -> - (* In order to avoid name conflicts: - * - if the forward is eliminated, we add the suffix "_fwd" (it won't be used) - * - otherwise, no suffix (because the backward functions will have a suffix) - *) - if num_backs = 1 && not keep_fwd then "_fwd" else "") + | None -> + if + (* In order to avoid name conflicts: + * - if the forward is eliminated, we add the suffix "_fwd" (it won't be used) + * - otherwise, no suffix (because the backward functions will have a suffix) + *) + num_backs = 1 && not keep_fwd + then "_fwd" + else "" | Some rg -> assert (num_region_groups > 0 && num_backs > 0); if num_backs = 1 then (* Exactly one backward function *) - match !Config.backend with - | FStar | Coq | HOL4 -> if not keep_fwd then "_fwd_back" else "_back" - | Lean -> if not keep_fwd then "" else "_back" + if not keep_fwd then "" else "_back" else if (* Several region groups/backward functions: - if all the regions in the group have names, we use those names diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml new file mode 100644 index 00000000..a54ab604 --- /dev/null +++ b/compiler/ExtractBuiltin.ml @@ -0,0 +1,648 @@ +(** This file declares external identifiers that we catch to map them to + definitions coming from the standard libraries in our backends. + + TODO: there misses trait **implementations** + *) + +open Names +open Config + +type simple_name = string list [@@deriving show, ord] + +let name_to_simple_name (s : name) : simple_name = + (* We simply ignore the disambiguators *) + List.filter_map (function Ident id -> Some id | Disambiguator _ -> None) s + +(** Small helper which cuts a string at the occurrences of "::" *) +let string_to_simple_name (s : string) : simple_name = + (* No function to split by using string separator?? *) + let name = String.split_on_char ':' s in + List.filter (fun s -> s <> "") name + +module SimpleNameOrd = struct + type t = simple_name + + let compare = compare_simple_name + let to_string = show_simple_name + let pp_t = pp_simple_name + let show_t = show_simple_name +end + +module SimpleNameMap = Collections.MakeMap (SimpleNameOrd) +module SimpleNameSet = Collections.MakeSet (SimpleNameOrd) + +(** Small utility to memoize some computations *) +let mk_memoized (f : unit -> 'a) : unit -> 'a = + let r = ref None in + let g () = + match !r with + | Some x -> x + | None -> + let x = f () in + r := Some x; + x + in + g + +(** Switch between two values depending on the target backend. + + We often compute the same value (typically: a name) if the target + is F*, Coq or HOL4, and a different value if the target is Lean. + *) +let backend_choice (fstar_coq_hol4 : 'a) (lean : 'a) : 'a = + match !backend with Coq | FStar | HOL4 -> fstar_coq_hol4 | Lean -> lean + +let builtin_globals : (string * string) list = + [ + (* Min *) + ("core::num::usize::MIN", "core_usize_min"); + ("core::num::u8::MIN", "core_u8_min"); + ("core::num::u16::MIN", "core_u16_min"); + ("core::num::u32::MIN", "core_u32_min"); + ("core::num::u64::MIN", "core_u64_min"); + ("core::num::u128::MIN", "core_u128_min"); + ("core::num::isize::MIN", "core_isize_min"); + ("core::num::i8::MIN", "core_i8_min"); + ("core::num::i16::MIN", "core_i16_min"); + ("core::num::i32::MIN", "core_i32_min"); + ("core::num::i64::MIN", "core_i64_min"); + ("core::num::i128::MIN", "core_i128_min"); + (* Max *) + ("core::num::usize::MAX", "core_usize_max"); + ("core::num::u8::MAX", "core_u8_max"); + ("core::num::u16::MAX", "core_u16_max"); + ("core::num::u32::MAX", "core_u32_max"); + ("core::num::u64::MAX", "core_u64_max"); + ("core::num::u128::MAX", "core_u128_max"); + ("core::num::isize::MAX", "core_isize_max"); + ("core::num::i8::MAX", "core_i8_max"); + ("core::num::i16::MAX", "core_i16_max"); + ("core::num::i32::MAX", "core_i32_max"); + ("core::num::i64::MAX", "core_i64_max"); + ("core::num::i128::MAX", "core_i128_max"); + ] + +let builtin_globals_map : string SimpleNameMap.t = + SimpleNameMap.of_list + (List.map (fun (x, y) -> (string_to_simple_name x, y)) builtin_globals) + +type builtin_variant_info = { fields : (string * string) list } +[@@deriving show] + +type builtin_enum_variant_info = { + rust_variant_name : string; + extract_variant_name : string; + fields : string list option; +} +[@@deriving show] + +type builtin_type_body_info = + | Struct of string * (string * string) list + (* The constructor name and the map for the field names *) + | Enum of builtin_enum_variant_info list +(* For every variant, a map for the field names *) +[@@deriving show] + +type builtin_type_info = { + rust_name : string list; + extract_name : string; + keep_params : bool list option; + (** We might want to filter some of the type parameters. + + For instance, `Vec` type takes a type parameter for the allocator, + which we want to ignore. + *) + body_info : builtin_type_body_info option; +} +[@@deriving show] + +type type_variant_kind = + | KOpaque + | KStruct of (string * string) list + (* TODO: handle the tuple case *) + | KEnum (* TODO *) + +let mk_struct_constructor (type_name : string) : string = + let prefix = + match !backend with FStar -> "Mk" | Coq | HOL4 -> "mk" | Lean -> "" + in + let suffix = match !backend with FStar | Coq | HOL4 -> "" | Lean -> ".mk" in + prefix ^ type_name ^ suffix + +(** The assumed types. + + The optional list of booleans is filtering information for the type + parameters. For instance, in the case of the `Vec` functions, there is + a type parameter for the allocator to use, which we want to filter. + *) +let builtin_types () : builtin_type_info list = + let mk_type (rust_name : string list) ?(keep_params : bool list option = None) + ?(kind : type_variant_kind = KOpaque) () : builtin_type_info = + let extract_name = + let sep = backend_choice "_" "." in + String.concat sep rust_name + in + let body_info : builtin_type_body_info option = + match kind with + | KOpaque -> None + | KStruct fields -> + let fields = + List.map + (fun (rname, name) -> + ( rname, + match !backend with + | FStar | Lean -> name + | Coq | HOL4 -> extract_name ^ "_" ^ name )) + fields + in + let constructor = mk_struct_constructor extract_name in + Some (Struct (constructor, fields)) + | KEnum -> raise (Failure "TODO") + in + { rust_name; extract_name; keep_params; body_info } + in + + [ + (* Alloc *) + mk_type [ "alloc"; "alloc"; "Global" ] (); + (* Vec *) + mk_type [ "alloc"; "vec"; "Vec" ] ~keep_params:(Some [ true; false ]) (); + (* Range *) + mk_type + [ "core"; "ops"; "range"; "Range" ] + ~kind:(KStruct [ ("start", "start"); ("end", "end_") ]) + (); + (* Option + + This one is more custom because we use the standard "option" type from + the target backend. + *) + { + rust_name = [ "core"; "option"; "Option" ]; + extract_name = + (match !backend with + | Lean -> "Option" + | Coq | FStar | HOL4 -> "option"); + keep_params = None; + body_info = + Some + (Enum + [ + { + rust_variant_name = "None"; + extract_variant_name = + (match !backend with + | FStar | Coq -> "None" + | Lean -> "none" + | HOL4 -> "NONE"); + fields = None; + }; + { + rust_variant_name = "Some"; + extract_variant_name = + (match !backend with + | FStar | Coq -> "Some" + | Lean -> "some" + | HOL4 -> "SOME"); + fields = None; + }; + ]); + }; + ] + +let mk_builtin_types_map () = + SimpleNameMap.of_list + (List.map (fun info -> (info.rust_name, info)) (builtin_types ())) + +let builtin_types_map = mk_memoized mk_builtin_types_map + +type builtin_fun_info = { + rg : Types.RegionGroupId.id option; + extract_name : string; +} +[@@deriving show] + +(** The assumed functions. + + The optional list of booleans is filtering information for the type + parameters. For instance, in the case of the `Vec` functions, there is + a type parameter for the allocator to use, which we want to filter. + *) +let builtin_funs () : + (string list * bool list option * builtin_fun_info list) list = + let rg0 = Some Types.RegionGroupId.zero in + (* Small utility *) + let mk_fun (name : string list) (extract_name : string list option) + (filter : bool list option) (with_back : bool) (back_no_suffix : bool) : + string list * bool list option * builtin_fun_info list = + let extract_name = + match extract_name with None -> name | Some name -> name + in + let basename = + match !backend with + | FStar | Coq | HOL4 -> String.concat "_" extract_name + | Lean -> String.concat "." extract_name + in + let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in + let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in + let back_suffix = if with_back && back_no_suffix then "" else "_back" in + let back = + if with_back then [ { rg = rg0; extract_name = basename ^ back_suffix } ] + else [] + in + (name, filter, fwd @ back) + in + [ + mk_fun [ "core"; "mem"; "replace" ] None None true false; + mk_fun [ "alloc"; "vec"; "Vec"; "new" ] None None false false; + mk_fun + [ "alloc"; "vec"; "Vec"; "push" ] + None + (Some [ true; false ]) + true true; + mk_fun + [ "alloc"; "vec"; "Vec"; "insert" ] + None + (Some [ true; false ]) + true true; + mk_fun + [ "alloc"; "vec"; "Vec"; "len" ] + None + (Some [ true; false ]) + true false; + mk_fun + [ "alloc"; "vec"; "Vec"; "index" ] + None + (Some [ true; true; false ]) + true false; + mk_fun + [ "alloc"; "vec"; "Vec"; "index_mut" ] + None + (Some [ true; true; false ]) + true false; + mk_fun + [ "alloc"; "boxed"; "Box"; "deref" ] + None + (Some [ true; false ]) + true false; + mk_fun + [ "alloc"; "boxed"; "Box"; "deref_mut" ] + None + (Some [ true; false ]) + true false; + (* TODO: fix the same like "[T]" below *) + mk_fun + [ "core"; "slice"; "index"; "[T]"; "index" ] + (Some [ "core"; "slice"; "index"; "Slice"; "index" ]) + None true false; + mk_fun + [ "core"; "slice"; "index"; "[T]"; "index_mut" ] + (Some [ "core"; "slice"; "index"; "Slice"; "index_mut" ]) + None true false; + mk_fun + [ "core"; "array"; "[T; N]"; "index" ] + (Some [ "core"; "array"; "Array"; "index" ]) + None true false; + mk_fun + [ "core"; "array"; "[T; N]"; "index_mut" ] + (Some [ "core"; "array"; "Array"; "index_mut" ]) + None true false; + mk_fun [ "core"; "slice"; "index"; "Range"; "get" ] None None true false; + mk_fun [ "core"; "slice"; "index"; "Range"; "get_mut" ] None None true false; + mk_fun [ "core"; "slice"; "index"; "Range"; "index" ] None None true false; + mk_fun + [ "core"; "slice"; "index"; "Range"; "index_mut" ] + None None true false; + mk_fun + [ "core"; "slice"; "index"; "Range"; "get_unchecked" ] + None None false false; + mk_fun + [ "core"; "slice"; "index"; "Range"; "get_unchecked_mut" ] + None None false false; + mk_fun + [ "core"; "slice"; "index"; "usize"; "get" ] + (Some [ "core"; "slice"; "index"; "Usize"; "get" ]) + None true false; + mk_fun + [ "core"; "slice"; "index"; "usize"; "get_mut" ] + (Some [ "core"; "slice"; "index"; "Usize"; "get_mut" ]) + None true false; + mk_fun + [ "core"; "slice"; "index"; "usize"; "get_unchecked" ] + (Some [ "core"; "slice"; "index"; "Usize"; "get_unchecked" ]) + None false false; + mk_fun + [ "core"; "slice"; "index"; "usize"; "get_unchecked_mut" ] + (Some [ "core"; "slice"; "index"; "Usize"; "get_unchecked_mut" ]) + None false false; + mk_fun + [ "core"; "slice"; "index"; "usize"; "index" ] + (Some [ "core"; "slice"; "index"; "Usize"; "index" ]) + None true false; + mk_fun + [ "core"; "slice"; "index"; "usize"; "index_mut" ] + (Some [ "core"; "slice"; "index"; "Usize"; "index_mut" ]) + None true false; + ] + +let mk_builtin_funs_map () = + SimpleNameMap.of_list + (List.map + (fun (name, filter, info) -> (name, (filter, info))) + (builtin_funs ())) + +let builtin_funs_map = mk_memoized mk_builtin_funs_map + +type effect_info = { can_fail : bool; stateful : bool } + +let builtin_fun_effects = + let int_names = + [ + "usize"; + "u8"; + "u16"; + "u32"; + "u64"; + "u128"; + "isize"; + "i8"; + "i16"; + "i32"; + "i64"; + "i128"; + ] + in + let int_ops = + [ "wrapping_add"; "wrapping_sub"; "rotate_left"; "rotate_right" ] + in + let int_funs = + List.map + (fun int_name -> + List.map (fun op -> "core::num::" ^ int_name ^ "::" ^ op) int_ops) + int_names + in + let int_funs = List.concat int_funs in + let no_fail_no_state_funs = + [ + (* TODO: redundancy with the funs information below *) + "alloc::vec::Vec::new"; + "alloc::vec::Vec::len"; + "alloc::boxed::Box::deref"; + "alloc::boxed::Box::deref_mut"; + "core::mem::replace"; + "core::mem::take"; + ] + @ int_funs + in + let no_fail_no_state_funs = + List.map + (fun n -> (n, { can_fail = false; stateful = false })) + no_fail_no_state_funs + in + let no_state_funs = + [ + (* TODO: redundancy with the funs information below *) + "alloc::vec::Vec::push"; + "alloc::vec::Vec::index"; + "alloc::vec::Vec::index_mut"; + "alloc::vec::Vec::index_mut_back"; + ] + in + let no_state_funs = + List.map (fun n -> (n, { can_fail = true; stateful = false })) no_state_funs + in + no_fail_no_state_funs @ no_state_funs + +let builtin_fun_effects_map = + SimpleNameMap.of_list + (List.map (fun (n, x) -> (string_to_simple_name n, x)) builtin_fun_effects) + +type builtin_trait_decl_info = { + rust_name : string; + extract_name : string; + constructor : string; + parent_clauses : string list; + consts : (string * string) list; + types : (string * (string * string list)) list; + (** Every type has: + - a Rust name + - an extraction name + - a list of clauses *) + methods : (string * builtin_fun_info list) list; +} +[@@deriving show] + +let builtin_trait_decls_info () = + let rg0 = Some Types.RegionGroupId.zero in + let mk_trait (rust_name : string list) ?(extract_name : string option = None) + ?(parent_clauses : string list = []) ?(types : string list = []) + ?(methods : (string * bool) list = []) () : builtin_trait_decl_info = + let extract_name = + match extract_name with + | Some n -> n + | None -> ( + match !backend with + | Coq | FStar | HOL4 -> String.concat "_" rust_name + | Lean -> String.concat "." rust_name) + in + let constructor = mk_struct_constructor extract_name in + let consts = [] in + let types = + let mk_type item_name = + let type_name = + match !backend with + | Coq | FStar | HOL4 -> extract_name ^ "_" ^ item_name + | Lean -> item_name + in + let clauses = [] in + (item_name, (type_name, clauses)) + in + List.map mk_type types + in + let methods = + let mk_method (item_name, with_back) = + (* TODO: factor out with builtin_funs_info *) + let basename = + match !backend with + | Coq | FStar | HOL4 -> extract_name ^ "_" ^ item_name + | Lean -> item_name + in + let back_no_suffix = false in + let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in + let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in + let back_suffix = if with_back && back_no_suffix then "" else "_back" in + let back = + if with_back then + [ { rg = rg0; extract_name = basename ^ back_suffix } ] + else [] + in + (item_name, fwd @ back) + in + List.map mk_method methods + in + let rust_name = String.concat "::" rust_name in + { + rust_name; + extract_name; + constructor; + parent_clauses; + consts; + types; + methods; + } + in + [ + (* Deref *) + mk_trait + [ "core"; "ops"; "deref"; "Deref" ] + ~types:[ "Target" ] + ~methods:[ ("deref", true) ] + (); + (* DerefMut *) + mk_trait + [ "core"; "ops"; "deref"; "DerefMut" ] + ~parent_clauses:[ backend_choice "deref_inst" "derefInst" ] + ~methods:[ ("deref_mut", true) ] + (); + (* Index *) + mk_trait + [ "core"; "ops"; "index"; "Index" ] + ~types:[ "Output" ] + ~methods:[ ("index", true) ] + (); + (* IndexMut *) + mk_trait + [ "core"; "ops"; "index"; "IndexMut" ] + ~parent_clauses:[ backend_choice "index_inst" "indexInst" ] + ~methods:[ ("index_mut", true) ] + (); + (* Sealed *) + mk_trait [ "core"; "slice"; "index"; "private_slice_index"; "Sealed" ] (); + (* SliceIndex *) + mk_trait + [ "core"; "slice"; "index"; "SliceIndex" ] + ~parent_clauses:[ backend_choice "sealed_inst" "sealedInst" ] + ~types:[ "Output" ] + ~methods: + [ + ("get", true); + ("get_mut", true); + ("get_unchecked", false); + ("get_unchecked_mut", false); + ("index", true); + ("index_mut", true); + ] + (); + ] + +let mk_builtin_trait_decls_map () = + SimpleNameMap.of_list + (List.map + (fun info -> (string_to_simple_name info.rust_name, info)) + (builtin_trait_decls_info ())) + +let builtin_trait_decls_map = mk_memoized mk_builtin_trait_decls_map + +(* TODO: generalize this. + + For now, the key is: + - name of the impl (ex.: "alloc.boxed.Boxed") + - name of the implemented trait (ex.: "core.ops.deref.Deref" +*) +type simple_name_pair = simple_name * simple_name [@@deriving show, ord] + +module SimpleNamePairOrd = struct + type t = simple_name_pair + + let compare = compare_simple_name_pair + let to_string = show_simple_name_pair + let pp_t = pp_simple_name_pair + let show_t = show_simple_name_pair +end + +module SimpleNamePairMap = Collections.MakeMap (SimpleNamePairOrd) + +let builtin_trait_impls_info () : + ((string list * string list) * (bool list option * string)) list = + let fmt (type_name : string list) + ?(extract_type_name : string list option = None) + (trait_name : string list) ?(filter : bool list option = None) () : + (string list * string list) * (bool list option * string) = + let name = + let trait_name = String.concat "" trait_name ^ "Inst" in + let sep = backend_choice "_" "." in + let type_name = + match extract_type_name with + | Some type_name -> type_name + | None -> type_name + in + String.concat sep type_name ^ sep ^ trait_name + in + ((type_name, trait_name), (filter, name)) + in + (* TODO: fix the names like "[T]" below *) + [ + (* core::ops::Deref<alloc::boxed::Box<T>> *) + fmt [ "alloc"; "boxed"; "Box" ] [ "core"; "ops"; "deref"; "Deref" ] (); + (* core::ops::DerefMut<alloc::boxed::Box<T>> *) + fmt [ "alloc"; "boxed"; "Box" ] [ "core"; "ops"; "deref"; "DerefMut" ] (); + (* core::ops::index::Index<[T], I> *) + fmt + [ "core"; "slice"; "index"; "[T]" ] + ~extract_type_name:(Some [ "core"; "slice"; "index"; "Slice" ]) + [ "core"; "ops"; "index"; "Index" ] + (); + (* core::ops::index::IndexMut<[T], I> *) + fmt + [ "core"; "slice"; "index"; "[T]" ] + ~extract_type_name:(Some [ "core"; "slice"; "index"; "Slice" ]) + [ "core"; "ops"; "index"; "IndexMut" ] + (); + (* core::slice::index::private_slice_index::Sealed<Range<usize>> *) + fmt + [ "core"; "slice"; "index"; "private_slice_index"; "Range" ] + [ "core"; "slice"; "index"; "private_slice_index"; "Sealed" ] + (); + (* core::slice::index::SliceIndex<Range<usize>, [T]> *) + fmt + [ "core"; "slice"; "index"; "Range" ] + [ "core"; "slice"; "index"; "SliceIndex" ] + (); + (* core::ops::index::Index<[T; N], I> *) + fmt + [ "core"; "array"; "[T; N]" ] + ~extract_type_name:(Some [ "core"; "array"; "Array" ]) + [ "core"; "ops"; "index"; "Index" ] + (); + (* core::ops::index::IndexMut<[T; N], I> *) + fmt + [ "core"; "array"; "[T; N]" ] + ~extract_type_name:(Some [ "core"; "array"; "Array" ]) + [ "core"; "ops"; "index"; "IndexMut" ] + (); + (* core::slice::index::private_slice_index::Sealed<usize> *) + fmt + [ "core"; "slice"; "index"; "private_slice_index"; "usize" ] + [ "core"; "slice"; "index"; "private_slice_index"; "Sealed" ] + (); + (* core::slice::index::SliceIndex<usize, [T]> *) + fmt + [ "core"; "slice"; "index"; "usize" ] + [ "core"; "slice"; "index"; "SliceIndex" ] + (); + (* core::ops::index::Index<Vec<T>, T> *) + fmt [ "alloc"; "vec"; "Vec" ] + [ "core"; "ops"; "index"; "Index" ] + ~filter:(Some [ true; true; false ]) + (); + (* core::ops::index::IndexMut<Vec<T>, T> *) + fmt [ "alloc"; "vec"; "Vec" ] + [ "core"; "ops"; "index"; "IndexMut" ] + ~filter:(Some [ true; true; false ]) + (); + ] + +let mk_builtin_trait_impls_map () = + SimpleNamePairMap.of_list (builtin_trait_impls_info ()) + +let builtin_trait_impls_map = mk_memoized mk_builtin_trait_impls_map diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml new file mode 100644 index 00000000..77f76bb4 --- /dev/null +++ b/compiler/ExtractTypes.ml @@ -0,0 +1,2477 @@ +(** The generic extraction *) +(* Turn the whole module into a functor: it is very annoying to carry the + the formatter everywhere... +*) + +open Pure +open PureUtils +open TranslateCore +open ExtractBase +open StringUtils +open Config +module F = Format + +(** Small helper to compute the name of an int type *) +let int_name (int_ty : integer_type) = + let isize, usize, i_format, u_format = + match !backend with + | FStar | Coq | HOL4 -> + ("isize", "usize", format_of_string "i%d", format_of_string "u%d") + | Lean -> ("Isize", "Usize", format_of_string "I%d", format_of_string "U%d") + in + match int_ty with + | Isize -> isize + | I8 -> Printf.sprintf i_format 8 + | I16 -> Printf.sprintf i_format 16 + | I32 -> Printf.sprintf i_format 32 + | I64 -> Printf.sprintf i_format 64 + | I128 -> Printf.sprintf i_format 128 + | Usize -> usize + | U8 -> Printf.sprintf u_format 8 + | U16 -> Printf.sprintf u_format 16 + | U32 -> Printf.sprintf u_format 32 + | U64 -> Printf.sprintf u_format 64 + | U128 -> Printf.sprintf u_format 128 + +(** Small helper to compute the name of a unary operation *) +let unop_name (unop : unop) : string = + match unop with + | Not -> ( + match !backend with FStar | Lean -> "not" | Coq -> "negb" | HOL4 -> "~") + | Neg (int_ty : integer_type) -> ( + match !backend with Lean -> "-" | _ -> int_name int_ty ^ "_neg") + | Cast _ -> + (* We never directly use the unop name in this case *) + raise (Failure "Unsupported") + +(** Small helper to compute the name of a binary operation (note that many + binary operations like "less than" are extracted to primitive operations, + like [<]). + *) +let named_binop_name (binop : E.binop) (int_ty : integer_type) : string = + let binop = + match binop with + | Div -> "div" + | Rem -> "rem" + | Add -> "add" + | Sub -> "sub" + | Mul -> "mul" + | Lt -> "lt" + | Le -> "le" + | Ge -> "ge" + | Gt -> "gt" + | BitXor -> "xor" + | BitAnd -> "and" + | BitOr -> "or" + | Shl -> "lsl" + | Shr -> + "asr" + (* NOTE: make sure arithmetic shift right is implemented, i.e. OCaml's asr operator, not lsr *) + | _ -> raise (Failure "Unreachable") + in + (* Remark: the Lean case is actually not used *) + match !backend with + | Lean -> int_name int_ty ^ "." ^ binop + | FStar | Coq | HOL4 -> int_name int_ty ^ "_" ^ binop + +(** A list of keywords/identifiers used by the backend and with which we + want to check collision. + + Remark: this is useful mostly to look for collisions when generating + names for *variables*. + *) +let keywords () = + let named_unops = + unop_name Not + :: List.map (fun it -> unop_name (Neg it)) T.all_signed_int_types + in + let named_binops = [ E.Div; Rem; Add; Sub; Mul ] in + let named_binops = + List.concat_map + (fun bn -> List.map (fun it -> named_binop_name bn it) T.all_int_types) + named_binops + in + let misc = + match !backend with + | FStar -> + [ + "assert"; + "assert_norm"; + "assume"; + "else"; + "fun"; + "fn"; + "FStar"; + "FStar.Mul"; + "if"; + "in"; + "include"; + "int"; + "let"; + "list"; + "match"; + "open"; + "rec"; + "scalar_cast"; + "then"; + "type"; + "Type0"; + "Type"; + "unit"; + "val"; + "with"; + ] + | Coq -> + [ + "assert"; + "Arguments"; + "Axiom"; + "char_of_byte"; + "Check"; + "Declare"; + "Definition"; + "else"; + "End"; + "fun"; + "Fixpoint"; + "if"; + "in"; + "int"; + "Inductive"; + "Import"; + "let"; + "Lemma"; + "match"; + "Module"; + "not"; + "Notation"; + "Proof"; + "Qed"; + "rec"; + "Record"; + "Require"; + "Scope"; + "Search"; + "SearchPattern"; + "Set"; + "then"; + (* [tt] is unit *) + "tt"; + "type"; + "Type"; + "unit"; + "with"; + ] + | Lean -> + [ + "by"; + "class"; + "decreasing_by"; + "def"; + "deriving"; + "do"; + "else"; + "end"; + "for"; + "have"; + "if"; + "inductive"; + "instance"; + "import"; + "let"; + "macro"; + "match"; + "namespace"; + "opaque"; + "open"; + "run_cmd"; + "set_option"; + "simp"; + "structure"; + "syntax"; + "termination_by"; + "then"; + "Type"; + "unsafe"; + "where"; + "with"; + "opaque_defs"; + ] + | HOL4 -> + [ + "Axiom"; + "case"; + "Definition"; + "else"; + "End"; + "fix"; + "fix_exec"; + "fn"; + "fun"; + "if"; + "in"; + "int"; + "Inductive"; + "let"; + "of"; + "Proof"; + "QED"; + "then"; + "Theorem"; + ] + in + List.concat [ named_unops; named_binops; misc ] + +let assumed_adts () : (assumed_ty * string) list = + match !backend with + | Lean -> + [ + (State, "State"); + (Result, "Result"); + (Error, "Error"); + (Fuel, "Nat"); + (Array, "Array"); + (Slice, "Slice"); + (Str, "Str"); + (RawPtr Mut, "MutRawPtr"); + (RawPtr Const, "ConstRawPtr"); + ] + | Coq | FStar | HOL4 -> + [ + (State, "state"); + (Result, "result"); + (Error, "error"); + (Fuel, if !backend = HOL4 then "num" else "nat"); + (Array, "array"); + (Slice, "slice"); + (Str, "str"); + (RawPtr Mut, "mut_raw_ptr"); + (RawPtr Const, "const_raw_ptr"); + ] + +let assumed_struct_constructors () : (assumed_ty * string) list = + match !backend with + | Lean -> [ (Array, "Array.make") ] + | Coq -> [ (Array, "mk_array") ] + | FStar -> [ (Array, "mk_array") ] + | HOL4 -> [ (Array, "mk_array") ] + +let assumed_variants () : (assumed_ty * VariantId.id * string) list = + match !backend with + | FStar -> + [ + (Result, result_return_id, "Return"); + (Result, result_fail_id, "Fail"); + (Error, error_failure_id, "Failure"); + (Error, error_out_of_fuel_id, "OutOfFuel"); + (* No Fuel::Zero on purpose *) + (* No Fuel::Succ on purpose *) + ] + | Coq -> + [ + (Result, result_return_id, "Return"); + (Result, result_fail_id, "Fail_"); + (Error, error_failure_id, "Failure"); + (Error, error_out_of_fuel_id, "OutOfFuel"); + (Fuel, fuel_zero_id, "O"); + (Fuel, fuel_succ_id, "S"); + ] + | Lean -> + [ + (Result, result_return_id, "ret"); + (Result, result_fail_id, "fail"); + (Error, error_failure_id, "panic"); + (* No Fuel::Zero on purpose *) + (* No Fuel::Succ on purpose *) + ] + | HOL4 -> + [ + (Result, result_return_id, "Return"); + (Result, result_fail_id, "Fail"); + (Error, error_failure_id, "Failure"); + (* No Fuel::Zero on purpose *) + (* No Fuel::Succ on purpose *) + ] + +let assumed_llbc_functions () : + (A.assumed_fun_id * T.RegionGroupId.id option * string) list = + let rg0 = Some T.RegionGroupId.zero in + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexShared, None, "array_index_usize"); + (ArrayIndexMut, None, "array_index_usize"); + (ArrayIndexMut, rg0, "array_update_usize"); + (ArrayToSliceShared, None, "array_to_slice"); + (ArrayToSliceMut, None, "array_to_slice"); + (ArrayToSliceMut, rg0, "array_from_slice"); + (ArrayRepeat, None, "array_repeat"); + (SliceIndexShared, None, "slice_index_usize"); + (SliceIndexMut, None, "slice_index_usize"); + (SliceIndexMut, rg0, "slice_update_usize"); + (SliceLen, None, "slice_len"); + ] + | Lean -> + [ + (ArrayIndexShared, None, "Array.index_usize"); + (ArrayIndexMut, None, "Array.index_usize"); + (ArrayIndexMut, rg0, "Array.update_usize"); + (ArrayToSliceShared, None, "Array.to_slice"); + (ArrayToSliceMut, None, "Array.to_slice"); + (ArrayToSliceMut, rg0, "Array.from_slice"); + (ArrayRepeat, None, "Array.repeat"); + (SliceIndexShared, None, "Slice.index_usize"); + (SliceIndexMut, None, "Slice.index_usize"); + (SliceIndexMut, rg0, "Slice.update_usize"); + (SliceLen, None, "Slice.len"); + ] + +let assumed_pure_functions () : (pure_assumed_fun_id * string) list = + match !backend with + | FStar -> + [ + (Return, "return"); + (Fail, "fail"); + (Assert, "massert"); + (FuelDecrease, "decrease"); + (FuelEqZero, "is_zero"); + ] + | Coq -> + (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) + [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ] + | Lean -> + (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) + [ (Return, "return"); (Fail, "fail_"); (Assert, "massert") ] + | HOL4 -> + (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) + [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ] + +let names_map_init () : names_map_init = + { + keywords = keywords (); + assumed_adts = assumed_adts (); + assumed_structs = assumed_struct_constructors (); + assumed_variants = assumed_variants (); + assumed_llbc_functions = assumed_llbc_functions (); + assumed_pure_functions = assumed_pure_functions (); + } + +let extract_unop (extract_expr : bool -> texpression -> unit) + (fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit + = + match unop with + | Not | Neg _ -> + let unop = unop_name unop in + if inside then F.pp_print_string fmt "("; + F.pp_print_string fmt unop; + F.pp_print_space fmt (); + extract_expr true arg; + if inside then F.pp_print_string fmt ")" + | Cast (src, tgt) -> ( + (* HOL4 has a special treatment: because it doesn't support dependent + types, we don't have a specific operator for the cast *) + match !backend with + | HOL4 -> + (* Casting, say, an u32 to an i32 would be done as follows: + {[ + mk_i32 (u32_to_int x) + ]} + *) + if inside then F.pp_print_string fmt "("; + F.pp_print_string fmt ("mk_" ^ int_name tgt); + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + F.pp_print_string fmt (int_name src ^ "_to_int"); + F.pp_print_space fmt (); + extract_expr true arg; + F.pp_print_string fmt ")"; + if inside then F.pp_print_string fmt ")" + | FStar | Coq | Lean -> + (* Rem.: the source type is an implicit parameter *) + if inside then F.pp_print_string fmt "("; + let cast_str = + match !backend with + | Coq | FStar -> "scalar_cast" + | Lean -> (* TODO: I8.cast, I16.cast, etc.*) "Scalar.cast" + | HOL4 -> raise (Failure "Unreachable") + in + F.pp_print_string fmt cast_str; + F.pp_print_space fmt (); + if !backend <> Lean then ( + F.pp_print_string fmt + (StringUtils.capitalize_first_letter + (PrintPure.integer_type_to_string src)); + F.pp_print_space fmt ()); + if !backend = Lean then F.pp_print_string fmt ("." ^ int_name tgt) + else + F.pp_print_string fmt + (StringUtils.capitalize_first_letter + (PrintPure.integer_type_to_string tgt)); + F.pp_print_space fmt (); + extract_expr true arg; + if inside then F.pp_print_string fmt ")") + +(** [extract_expr] : the boolean argument is [inside] *) +let extract_binop (extract_expr : bool -> texpression -> unit) + (fmt : F.formatter) (inside : bool) (binop : E.binop) + (int_ty : integer_type) (arg0 : texpression) (arg1 : texpression) : unit = + if inside then F.pp_print_string fmt "("; + (* Some binary operations have a special notation depending on the backend *) + (match (!backend, binop) with + | HOL4, (Eq | Ne) + | (FStar | Coq | Lean), (Eq | Lt | Le | Ne | Ge | Gt) + | Lean, (Div | Rem | Add | Sub | Mul) -> + let binop = + match binop with + | Eq -> "=" + | Lt -> "<" + | Le -> "<=" + | Ne -> if !backend = Lean then "!=" else "<>" + | Ge -> ">=" + | Gt -> ">" + | Div -> "/" + | Rem -> "%" + | Add -> "+" + | Sub -> "-" + | Mul -> "*" + | _ -> raise (Failure "Unreachable") + in + let binop = + match !backend with FStar | Lean | HOL4 -> binop | Coq -> "s" ^ binop + in + extract_expr false arg0; + F.pp_print_space fmt (); + F.pp_print_string fmt binop; + F.pp_print_space fmt (); + extract_expr false arg1 + | _ -> + let binop = named_binop_name binop int_ty in + F.pp_print_string fmt binop; + F.pp_print_space fmt (); + extract_expr true arg0; + F.pp_print_space fmt (); + extract_expr true arg1); + if inside then F.pp_print_string fmt ")" + +let type_decl_kind_to_qualif (kind : decl_kind) + (type_kind : type_decl_kind option) : string option = + match !backend with + | FStar -> ( + match kind with + | SingleNonRec -> Some "type" + | SingleRec -> Some "type" + | MutRecFirst -> Some "type" + | MutRecInner -> Some "and" + | MutRecLast -> Some "and" + | Assumed -> Some "assume type" + | Declared -> Some "val") + | Coq -> ( + match (kind, type_kind) with + | SingleNonRec, Some Enum -> Some "Inductive" + | SingleNonRec, Some Struct -> Some "Record" + | (SingleRec | MutRecFirst), Some _ -> Some "Inductive" + | (MutRecInner | MutRecLast), Some _ -> + (* Coq doesn't support groups of mutually recursive definitions which mix + * records and inducties: we convert everything to records if this happens + *) + Some "with" + | (Assumed | Declared), None -> Some "Axiom" + | SingleNonRec, None -> + (* This is for traits *) + Some "Record" + | _ -> + raise + (Failure + ("Unexpected: (" ^ show_decl_kind kind ^ ", " + ^ Print.option_to_string show_type_decl_kind type_kind + ^ ")"))) + | Lean -> ( + match kind with + | SingleNonRec -> + if type_kind = Some Struct then Some "structure" else Some "inductive" + | SingleRec -> Some "inductive" + | MutRecFirst -> Some "inductive" + | MutRecInner -> Some "inductive" + | MutRecLast -> Some "inductive" + | Assumed -> Some "axiom" + | Declared -> Some "axiom") + | HOL4 -> None + +let fun_decl_kind_to_qualif (kind : decl_kind) : string option = + match !backend with + | FStar -> ( + match kind with + | SingleNonRec -> Some "let" + | SingleRec -> Some "let rec" + | MutRecFirst -> Some "let rec" + | MutRecInner -> Some "and" + | MutRecLast -> Some "and" + | Assumed -> Some "assume val" + | Declared -> Some "val") + | Coq -> ( + match kind with + | SingleNonRec -> Some "Definition" + | SingleRec -> Some "Fixpoint" + | MutRecFirst -> Some "Fixpoint" + | MutRecInner -> Some "with" + | MutRecLast -> Some "with" + | Assumed -> Some "Axiom" + | Declared -> Some "Axiom") + | Lean -> ( + match kind with + | SingleNonRec -> Some "def" + | SingleRec -> Some "divergent def" + | MutRecFirst -> Some "mutual divergent def" + | MutRecInner -> Some "divergent def" + | MutRecLast -> Some "divergent def" + | Assumed -> Some "axiom" + | Declared -> Some "axiom") + | HOL4 -> None + +(** The type of types. + + TODO: move inside the formatter? + *) +let type_keyword () = + match !backend with + | FStar -> "Type0" + | Coq | Lean -> "Type" + | HOL4 -> raise (Failure "Unexpected") + +(** + [ctx]: we use the context to lookup type definitions, to retrieve type names. + This is used to compute variable names, when they have no basenames: in this + case we use the first letter of the type name. + + [variant_concatenate_type_name]: if true, add the type name as a prefix + to the variant names. + Ex.: + In Rust: + {[ + enum List = { + Cons(u32, Box<List>),x + Nil, + } + ]} + + F*, if option activated: + {[ + type list = + | ListCons : u32 -> list -> list + | ListNil : list + ]} + + F*, if option not activated: + {[ + type list = + | Cons : u32 -> list -> list + | Nil : list + ]} + + Rk.: this should be true by default, because in Rust all the variant names + are actively uniquely identifier by the type name [List::Cons(...)], while + in other languages it is not necessarily the case, and thus clashes can mess + up type checking. Note that some languages actually forbids the name clashes + (it is the case of F* ). + *) +let mk_formatter (ctx : trans_ctx) (crate_name : string) + (variant_concatenate_type_name : bool) : formatter = + let int_name = int_name in + + (* Prepare a name. + * The first id elem is always the crate: if it is the local crate, + * we remove it. + * We also remove all the disambiguators, then convert everything to strings. + * **Rmk:** because we remove the disambiguators, there may be name collisions + * (which is ok, because we check for name collisions and fail if there is any). + *) + let get_name (name : name) : string list = + (* Rmk.: initially we only filtered the disambiguators equal to 0 *) + let name = Names.filter_disambiguators name in + match name with + | Ident crate :: name -> + let name = if crate = crate_name then name else Ident crate :: name in + let name = + List.map + (function + | Names.Ident s -> s + | Disambiguator d -> Names.Disambiguator.to_string d) + name + in + name + | _ -> + raise (Failure ("Unexpected name shape: " ^ Print.name_to_string name)) + in + let flatten_name (name : string list) : string = + match !backend with + | FStar | Coq | HOL4 -> String.concat "_" name + | Lean -> String.concat "." name + in + let get_type_name = get_name in + let get_type_name_no_suffix name = + match !backend with + | FStar | Coq | HOL4 -> String.concat "_" (get_type_name name) + | Lean -> String.concat "." (get_type_name name) + in + let type_name name = + match !backend with + | FStar -> + StringUtils.lowercase_first_letter (get_type_name_no_suffix name ^ "_t") + | Coq | HOL4 -> get_type_name_no_suffix name ^ "_t" + | Lean -> get_type_name_no_suffix name + in + let field_name (def_name : name) (field_id : FieldId.id) + (field_name : string option) : string = + let field_name_s = + match field_name with + | Some field_name -> field_name + | None -> + (* TODO: extract structs with no field names to tuples *) + FieldId.to_string field_id + in + if !Config.record_fields_short_names then + if field_name = None then (* TODO: this is a bit ugly *) + "_" ^ field_name_s + else field_name_s + else + let def_name = get_type_name_no_suffix def_name ^ "_" ^ field_name_s in + match !backend with + | Lean | HOL4 -> def_name + | Coq | FStar -> StringUtils.lowercase_first_letter def_name + in + let variant_name (def_name : name) (variant : string) : string = + match !backend with + | FStar | Coq | HOL4 -> + let variant = to_camel_case variant in + if variant_concatenate_type_name then + StringUtils.capitalize_first_letter + (get_type_name_no_suffix def_name ^ "_" ^ variant) + else variant + | Lean -> variant + in + let struct_constructor (basename : name) : string = + let tname = type_name basename in + ExtractBuiltin.mk_struct_constructor tname + in + let get_fun_name fname = + let fname = get_name fname in + (* TODO: don't convert to snake case for Coq, HOL4, F* *) + let fname = flatten_name fname in + match !backend with + | FStar | Coq | HOL4 -> StringUtils.lowercase_first_letter fname + | Lean -> fname + in + let global_name (name : global_name) : string = + (* Converting to snake case also lowercases the letters (in Rust, global + * names are written in capital letters). *) + let parts = List.map to_snake_case (get_name name) in + String.concat "_" parts + in + let fun_name (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option) + (num_rgs : int) (rg : region_group_info option) (filter_info : bool * int) + : string = + let fname = get_fun_name fname in + (* Compute the suffix *) + let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in + (* Concatenate *) + fname ^ suffix + in + + let trait_decl_name (trait_decl : trait_decl) : string = + type_name trait_decl.name + in + + let trait_impl_name (trait_decl : trait_decl) (trait_impl : trait_impl) : + string = + (* TODO: provisional: we concatenate the trait impl name (which is its type) + with the trait decl name *) + let trait_decl = + let name = trait_decl.name in + let name = get_type_name_no_suffix name ^ "Inst" in + (* Remove the occurrences of '.' *) + String.concat "" (String.split_on_char '.' name) + in + let name = flatten_name (get_type_name trait_impl.name @ [ trait_decl ]) in + match !backend with + | FStar -> StringUtils.lowercase_first_letter name + | Coq | HOL4 | Lean -> name + in + + let trait_decl_constructor (trait_decl : trait_decl) : string = + let name = trait_decl_name trait_decl in + ExtractBuiltin.mk_struct_constructor name + in + + let trait_parent_clause_name (trait_decl : trait_decl) (clause : trait_clause) + : string = + (* TODO: improve - it would be better to not use indices *) + let clause = "parent_clause_" ^ TraitClauseId.to_string clause.clause_id in + if !Config.record_fields_short_names then clause + else trait_decl_name trait_decl ^ "_" ^ clause + in + let trait_type_name (trait_decl : trait_decl) (item : string) : string = + let name = + if !Config.record_fields_short_names then item + else trait_decl_name trait_decl ^ "_" ^ item + in + (* Constants are usually all capital letters. + Some backends do not support field names starting with a capital letter, + and it may be weird to lowercase everything (especially as it may lead + to more name collisions): we add a prefix when necessary. + For instance, it gives: "U" -> "tU" + Note that for some backends we prepend the type name (because those backends + can't disambiguate fields coming from different ADTs if they have the same + names), and thus don't need to add a prefix starting with a lowercase. + *) + match !backend with FStar -> "t" ^ name | Coq | Lean | HOL4 -> name + in + let trait_const_name (trait_decl : trait_decl) (item : string) : string = + let name = + if !Config.record_fields_short_names then item + else trait_decl_name trait_decl ^ "_" ^ item + in + (* See [trait_type_name] *) + match !backend with FStar -> "c" ^ name | Coq | Lean | HOL4 -> name + in + let trait_method_name (trait_decl : trait_decl) (item : string) : string = + if !Config.record_fields_short_names then item + else trait_decl_name trait_decl ^ "_" ^ item + in + let trait_type_clause_name (trait_decl : trait_decl) (item : string) + (clause : trait_clause) : string = + (* TODO: improve - it would be better to not use indices *) + trait_type_name trait_decl item + ^ "_clause_" + ^ TraitClauseId.to_string clause.clause_id + in + + let termination_measure_name (_fid : A.FunDeclId.id) (fname : fun_name) + (num_loops : int) (loop_id : LoopId.id option) : string = + let fname = get_fun_name fname in + let lp_suffix = default_fun_loop_suffix num_loops loop_id in + (* Compute the suffix *) + let suffix = + match !Config.backend with + | FStar -> "_decreases" + | Lean -> "_terminates" + | Coq | HOL4 -> raise (Failure "Unexpected") + in + (* Concatenate *) + fname ^ lp_suffix ^ suffix + in + + let decreases_proof_name (_fid : A.FunDeclId.id) (fname : fun_name) + (num_loops : int) (loop_id : LoopId.id option) : string = + let fname = get_fun_name fname in + let lp_suffix = default_fun_loop_suffix num_loops loop_id in + (* Compute the suffix *) + let suffix = + match !Config.backend with + | Lean -> "_decreases" + | FStar | Coq | HOL4 -> raise (Failure "Unexpected") + in + (* Concatenate *) + fname ^ lp_suffix ^ suffix + in + + let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) + : string = + (* Small helper to derive var names from ADT type names. + + We do the following: + - convert the type name to snake case + - take the first letter of every "letter group" + Ex.: "HashMap" -> "hash_map" -> "hm" + *) + let name_from_type_ident (name : string) : string = + let cl = to_snake_case name in + let cl = String.split_on_char '_' cl in + let cl = List.filter (fun s -> String.length s > 0) cl in + assert (List.length cl > 0); + let cl = List.map (fun s -> s.[0]) cl in + StringUtils.string_of_chars cl + in + (* If there is a basename, we use it *) + match basename with + | Some basename -> + (* This should be a no-op *) + to_snake_case basename + | None -> ( + (* No basename: we use the first letter of the type *) + match ty with + | Adt (type_id, generics) -> ( + match type_id with + | Tuple -> + (* The "pair" case is frequent enough to have its special treatment *) + if List.length generics.types = 2 then "p" else "t" + | Assumed Result -> "r" + | Assumed Error -> ConstStrings.error_basename + | Assumed Fuel -> ConstStrings.fuel_basename + | Assumed Array -> "a" + | Assumed Slice -> "s" + | Assumed Str -> "s" + | Assumed State -> ConstStrings.state_basename + | Assumed (RawPtr _) -> "p" + | AdtId adt_id -> + let def = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in + (* Derive the var name from the last ident of the type name + * Ex.: ["hashmap"; "HashMap"] ~~> "HashMap" -> "hash_map" -> "hm" + *) + (* The name shouldn't be empty, and its last element should + * be an ident *) + let cl = List.nth def.name (List.length def.name - 1) in + name_from_type_ident (Names.as_ident cl)) + | TypeVar _ -> ( + (* TODO: use "t" also for F* *) + match !backend with + | FStar -> "x" (* lacking inspiration here... *) + | Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *)) + | Literal lty -> ( + match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i") + | Arrow _ -> "f" + | TraitType (_, _, name) -> name_from_type_ident name) + in + let type_var_basename (_varset : StringSet.t) (basename : string) : string = + (* Rust type variables are snake-case and start with a capital letter *) + match !backend with + | FStar -> + (* This is *not* a no-op: this removes the capital letter *) + to_snake_case basename + | HOL4 -> + (* In HOL4, type variable names must start with "'" *) + "'" ^ to_snake_case basename + | Coq | Lean -> basename + in + let const_generic_var_basename (_varset : StringSet.t) (basename : string) : + string = + (* Rust type variables are snake-case and start with a capital letter *) + match !backend with + | FStar | HOL4 -> + (* This is *not* a no-op: this removes the capital letter *) + to_snake_case basename + | Coq | Lean -> basename + in + let trait_clause_basename (_varset : StringSet.t) (_clause : trait_clause) : + string = + (* TODO: actually use the clause to derive the name *) + "inst" + in + let trait_self_clause_basename = "self_clause" in + let append_index (basename : string) (i : int) : string = + basename ^ string_of_int i + in + + let extract_literal (fmt : F.formatter) (inside : bool) (cv : literal) : unit + = + match cv with + | Scalar sv -> ( + match !backend with + | FStar -> F.pp_print_string fmt (Z.to_string sv.PV.value) + | Coq | HOL4 | Lean -> + let print_brackets = inside && !backend = HOL4 in + if print_brackets then F.pp_print_string fmt "("; + (match !backend with + | Coq | Lean -> () + | HOL4 -> + F.pp_print_string fmt ("int_to_" ^ int_name sv.PV.int_ty); + F.pp_print_space fmt () + | _ -> raise (Failure "Unreachable")); + (* We need to add parentheses if the value is negative *) + if sv.PV.value >= Z.of_int 0 then + F.pp_print_string fmt (Z.to_string sv.PV.value) + else if !backend = Lean then + (* TODO: parsing issues with Lean because there are ambiguous + interpretations between int values and nat values *) + F.pp_print_string fmt + ("(-(" ^ Z.to_string (Z.neg sv.PV.value) ^ ":Int))") + else F.pp_print_string fmt ("(" ^ Z.to_string sv.PV.value ^ ")"); + (match !backend with + | Coq -> + let iname = int_name sv.PV.int_ty in + F.pp_print_string fmt ("%" ^ iname) + | Lean -> + let iname = String.lowercase_ascii (int_name sv.PV.int_ty) in + F.pp_print_string fmt ("#" ^ iname) + | HOL4 -> () + | _ -> raise (Failure "Unreachable")); + if print_brackets then F.pp_print_string fmt ")") + | Bool b -> + let b = + match !backend with + | HOL4 -> if b then "T" else "F" + | Coq | FStar | Lean -> if b then "true" else "false" + in + F.pp_print_string fmt b + | Char c -> ( + match !backend with + | HOL4 -> + (* [#"a"] is a notation for [CHR 97] (97 is the ASCII code for 'a') *) + F.pp_print_string fmt ("#\"" ^ String.make 1 c ^ "\"") + | FStar | Lean -> F.pp_print_string fmt ("'" ^ String.make 1 c ^ "'") + | Coq -> + if inside then F.pp_print_string fmt "("; + F.pp_print_string fmt "char_of_byte"; + F.pp_print_space fmt (); + (* Convert the the char to ascii *) + let c = + let i = Char.code c in + let x0 = i / 16 in + let x1 = i mod 16 in + "Coq.Init.Byte.x" ^ string_of_int x0 ^ string_of_int x1 + in + F.pp_print_string fmt c; + if inside then F.pp_print_string fmt ")") + in + let bool_name = if !backend = Lean then "Bool" else "bool" in + let char_name = if !backend = Lean then "Char" else "char" in + let str_name = if !backend = Lean then "String" else "string" in + { + bool_name; + char_name; + int_name; + str_name; + type_decl_kind_to_qualif; + fun_decl_kind_to_qualif; + field_name; + variant_name; + struct_constructor; + type_name; + global_name; + fun_name; + termination_measure_name; + decreases_proof_name; + trait_decl_name; + trait_impl_name; + trait_decl_constructor; + trait_parent_clause_name; + trait_const_name; + trait_type_name; + trait_method_name; + trait_type_clause_name; + var_basename; + type_var_basename; + const_generic_var_basename; + trait_self_clause_basename; + trait_clause_basename; + append_index; + extract_literal; + extract_unop; + extract_binop; + } + +let mk_formatter_and_names_maps (ctx : trans_ctx) (crate_name : string) + (variant_concatenate_type_name : bool) : formatter * names_maps = + let fmt = mk_formatter ctx crate_name variant_concatenate_type_name in + let names_maps = initialize_names_maps fmt (names_map_init ()) in + (fmt, names_maps) + +let is_single_opaque_fun_decl_group (dg : Pure.fun_decl list) : bool = + match dg with [ d ] -> d.body = None | _ -> false + +let is_single_opaque_type_decl_group (dg : Pure.type_decl list) : bool = + match dg with [ d ] -> d.kind = Opaque | _ -> false + +let is_empty_record_type_decl (d : Pure.type_decl) : bool = d.kind = Struct [] + +let is_empty_record_type_decl_group (dg : Pure.type_decl list) : bool = + match dg with [ d ] -> is_empty_record_type_decl d | _ -> false + +(** In some provers, groups of definitions must be delimited. + + - in Coq, *every* group (including singletons) must end with "." + - in Lean, groups of mutually recursive definitions must end with "end" + - in HOL4 (in most situations) the whole group must be within a `Define` command + + Calls to {!extract_fun_decl} should be inserted between calls to + {!start_fun_decl_group} and {!end_fun_decl_group}. + + TODO: maybe those [{start/end}_decl_group] functions are not that much a good + idea and we should merge them with the corresponding [extract_decl] functions. + *) +let start_fun_decl_group (ctx : extraction_ctx) (fmt : F.formatter) + (is_rec : bool) (dg : Pure.fun_decl list) = + match !backend with + | FStar | Coq | Lean -> () + | HOL4 -> + (* In HOL4, opaque functions have a special treatment *) + if is_single_opaque_fun_decl_group dg then () + else + let compute_fun_def_name (def : Pure.fun_decl) : string = + ctx_get_local_function def.def_id def.loop_id def.back_id ctx ^ "_def" + in + let names = List.map compute_fun_def_name dg in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Open the box for the delimiters *) + F.pp_open_vbox fmt 0; + (* Open the box for the definitions themselves *) + F.pp_open_vbox fmt ctx.indent_incr; + (* Print the delimiters *) + if is_rec then + F.pp_print_string fmt + ("val [" ^ String.concat ", " names ^ "] = DefineDiv ‘") + else ( + assert (List.length names = 1); + let name = List.hd names in + F.pp_print_string fmt ("val " ^ name ^ " = Define ‘")); + F.pp_print_cut fmt () + +(** See {!start_fun_decl_group}. *) +let end_fun_decl_group (fmt : F.formatter) (is_rec : bool) + (dg : Pure.fun_decl list) = + match !backend with + | FStar -> () + | Coq -> + (* For aesthetic reasons, we print the Coq end group delimiter directly + in {!extract_fun_decl}. *) + () + | Lean -> + (* We must add the "end" keyword to groups of mutually recursive functions *) + if is_rec && List.length dg > 1 then ( + F.pp_print_cut fmt (); + F.pp_print_string fmt "end"; + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0) + else () + | HOL4 -> + (* In HOL4, opaque functions have a special treatment *) + if is_single_opaque_fun_decl_group dg then () + else ( + (* Close the box for the definitions *) + F.pp_close_box fmt (); + (* Print the end delimiter *) + F.pp_print_cut fmt (); + F.pp_print_string fmt "’"; + (* Close the box for the delimiters *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0) + +(** See {!start_fun_decl_group}: similar usage, but for the type declarations. *) +let start_type_decl_group (ctx : extraction_ctx) (fmt : F.formatter) + (is_rec : bool) (dg : Pure.type_decl list) = + match !backend with + | FStar | Coq -> () + | Lean -> + if is_rec && List.length dg > 1 then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "mutual"; + F.pp_print_space fmt ()) + | HOL4 -> + (* In HOL4, opaque types and empty records have a special treatment *) + if + is_single_opaque_type_decl_group dg + || is_empty_record_type_decl_group dg + then () + else ( + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Open the box for the delimiters *) + F.pp_open_vbox fmt 0; + (* Open the box for the definitions themselves *) + F.pp_open_vbox fmt ctx.indent_incr; + (* Print the delimiters *) + F.pp_print_string fmt "Datatype:"; + F.pp_print_cut fmt ()) + +(** See {!start_fun_decl_group}. *) +let end_type_decl_group (fmt : F.formatter) (is_rec : bool) + (dg : Pure.type_decl list) = + match !backend with + | FStar -> () + | Coq -> + (* For aesthetic reasons, we print the Coq end group delimiter directly + in {!extract_fun_decl}. *) + () + | Lean -> + (* We must add the "end" keyword to groups of mutually recursive functions *) + if is_rec && List.length dg > 1 then ( + F.pp_print_cut fmt (); + F.pp_print_string fmt "end"; + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0) + else () + | HOL4 -> + (* In HOL4, opaque types and empty records have a special treatment *) + if + is_single_opaque_type_decl_group dg + || is_empty_record_type_decl_group dg + then () + else ( + (* Close the box for the definitions *) + F.pp_close_box fmt (); + (* Print the end delimiter *) + F.pp_print_cut fmt (); + F.pp_print_string fmt "End"; + (* Close the box for the delimiters *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0) + +let unit_name () = + match !backend with Lean -> "Unit" | Coq | FStar | HOL4 -> "unit" + +(** Small helper *) +let extract_arrow (fmt : F.formatter) () : unit = + if !Config.backend = Lean then F.pp_print_string fmt "→" + else F.pp_print_string fmt "->" + +let extract_const_generic (ctx : extraction_ctx) (fmt : F.formatter) + (inside : bool) (cg : const_generic) : unit = + match cg with + | ConstGenericGlobal id -> + let s = ctx_get_global id ctx in + F.pp_print_string fmt s + | ConstGenericValue v -> ctx.fmt.extract_literal fmt inside v + | ConstGenericVar id -> + let s = ctx_get_const_generic_var id ctx in + F.pp_print_string fmt s + +let extract_literal_type (ctx : extraction_ctx) (fmt : F.formatter) + (ty : literal_type) : unit = + match ty with + | Bool -> F.pp_print_string fmt ctx.fmt.bool_name + | Char -> F.pp_print_string fmt ctx.fmt.char_name + | Integer int_ty -> F.pp_print_string fmt (ctx.fmt.int_name int_ty) + +(** [inside] constrols whether we should add parentheses or not around type + applications (if [true] we add parentheses). + + [no_params_tys]: for all the types inside this set, do not print the type parameters. + This is used for HOL4. As polymorphism is uniform in HOL4, printing the + type parameters in the recursive definitions is useless (and actually + forbidden). + + For instance, where in F* we would write: + {[ + type list a = | Nil : list a | Cons : a -> list a -> list a + ]} + + In HOL4 we would simply write: + {[ + Datatype: + list = Nil 'a | Cons 'a list + End + ]} + *) +let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit = + let extract_rec = extract_ty ctx fmt no_params_tys in + match ty with + | Adt (type_id, generics) -> ( + let has_params = generics <> empty_generic_args in + match type_id with + | Tuple -> + (* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type: + * we have to write [unit]... *) + if generics.types = [] then F.pp_print_string fmt (unit_name ()) + else ( + F.pp_print_string fmt "("; + Collections.List.iter_link + (fun () -> + F.pp_print_space fmt (); + let product = + match !backend with + | FStar -> "&" + | Coq -> "*" + | Lean -> "×" + | HOL4 -> "#" + in + F.pp_print_string fmt product; + F.pp_print_space fmt ()) + (extract_rec true) generics.types; + F.pp_print_string fmt ")") + | AdtId _ | Assumed _ -> ( + (* HOL4 behaves differently. Where in Coq/FStar/Lean we would write: + `tree a b` + + In HOL4 we would write: + `('a, 'b) tree` + *) + match !backend with + | FStar | Coq | Lean -> + let print_paren = inside && has_params in + if print_paren then F.pp_print_string fmt "("; + (* TODO: for now, only the opaque *functions* are extracted in the + opaque module. The opaque *types* are assumed. *) + F.pp_print_string fmt (ctx_get_type type_id ctx); + (* We might need to filter the type arguments, if the type + is builtin (for instance, we filter the global allocator type + argument for `Vec`). *) + let generics = + match type_id with + | AdtId id -> ( + match + TypeDeclId.Map.find_opt id ctx.types_filter_type_args_map + with + | None -> generics + | Some filter -> + let types = List.combine filter generics.types in + let types = + List.filter_map + (fun (b, ty) -> if b then Some ty else None) + types + in + { generics with types }) + | _ -> generics + in + extract_generic_args ctx fmt no_params_tys generics; + if print_paren then F.pp_print_string fmt ")" + | HOL4 -> + let { types; const_generics; trait_refs } = generics in + (* Const generics are not supported in HOL4 *) + assert (const_generics = []); + let print_tys = + match type_id with + | AdtId id -> not (TypeDeclId.Set.mem id no_params_tys) + | Assumed _ -> true + | _ -> raise (Failure "Unreachable") + in + if types <> [] && print_tys then ( + let print_paren = List.length types > 1 in + if print_paren then F.pp_print_string fmt "("; + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt ()) + (extract_rec true) types; + if print_paren then F.pp_print_string fmt ")"; + F.pp_print_space fmt ()); + F.pp_print_string fmt (ctx_get_type type_id ctx); + if trait_refs <> [] then ( + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_trait_ref ctx fmt no_params_tys true) + trait_refs))) + | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) + | Literal lty -> extract_literal_type ctx fmt lty + | Arrow (arg_ty, ret_ty) -> + if inside then F.pp_print_string fmt "("; + extract_rec false arg_ty; + F.pp_print_space fmt (); + extract_arrow fmt (); + F.pp_print_space fmt (); + extract_rec false ret_ty; + if inside then F.pp_print_string fmt ")" + | TraitType (trait_ref, generics, type_name) -> ( + if !parameterize_trait_types then raise (Failure "Unimplemented") + else + let type_name = + ctx_get_trait_type trait_ref.trait_decl_ref.trait_decl_id type_name + ctx + in + let add_brackets (s : string) = + if !backend = Coq then "(" ^ s ^ ")" else s + in + (* There may be a special treatment depending on the instance id. + See the comments for {!extract_trait_instance_id_with_dot}. + TODO: there should be a cleaner way to do. The annoying thing + here is that if we project directly over the self clause, then + we have to be careful (we may not have to print the "Self."). + Otherwise, we can directly call {!extract_trait_ref}. + *) + match trait_ref.trait_id with + | Self -> + assert (generics = empty_generic_args); + assert (trait_ref.generics = empty_generic_args); + extract_trait_instance_id_with_dot ctx fmt no_params_tys false + trait_ref.trait_id; + F.pp_print_string fmt type_name + | _ -> + (* HOL4 doesn't have 1st class types *) + assert (!backend <> HOL4); + let use_brackets = generics <> empty_generic_args in + if use_brackets then F.pp_print_string fmt "("; + extract_trait_ref ctx fmt no_params_tys false trait_ref; + extract_generic_args ctx fmt no_params_tys generics; + if use_brackets then F.pp_print_string fmt ")"; + F.pp_print_string fmt ("." ^ add_brackets type_name)) + +and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_ref) : unit = + let use_brackets = tr.generics <> empty_generic_args && inside in + if use_brackets then F.pp_print_string fmt "("; + (* We may need to filter the parameters if the trait is builtin *) + let generics = + match tr.trait_id with + | TraitImpl id -> ( + match + TraitImplId.Map.find_opt id ctx.trait_impls_filter_type_args_map + with + | None -> tr.generics + | Some filter -> + let types = + List.filter_map + (fun (b, x) -> if b then Some x else None) + (List.combine filter tr.generics.types) + in + { tr.generics with types }) + | _ -> tr.generics + in + extract_trait_instance_id ctx fmt no_params_tys inside tr.trait_id; + extract_generic_args ctx fmt no_params_tys generics; + if use_brackets then F.pp_print_string fmt ")" + +and extract_trait_decl_ref (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_decl_ref) : + unit = + let use_brackets = tr.decl_generics <> empty_generic_args && inside in + let name = ctx_get_trait_decl tr.trait_decl_id ctx in + if use_brackets then F.pp_print_string fmt "("; + F.pp_print_string fmt name; + (* There is something subtle here: the trait obligations for the implemented + trait are put inside the parent clauses, so we must ignore them here *) + let generics = { tr.decl_generics with trait_refs = [] } in + extract_generic_args ctx fmt no_params_tys generics; + if use_brackets then F.pp_print_string fmt ")" + +and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (generics : generic_args) : unit = + let { types; const_generics; trait_refs } = generics in + if !backend <> HOL4 then ( + if types <> [] then ( + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_ty ctx fmt no_params_tys true) + types); + if const_generics <> [] then ( + assert (!backend <> HOL4); + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_const_generic ctx fmt true) + const_generics)); + if trait_refs <> [] then ( + F.pp_print_space fmt (); + Collections.List.iter_link (F.pp_print_space fmt) + (extract_trait_ref ctx fmt no_params_tys true) + trait_refs) + +(** We sometimes need to ignore references to `Self` when generating the + code, espcially when we project associated items. For this reason we + have a special function for the cases where we project from an instance + id (e.g., `<Self as Foo>::foo` - note that in the extracted code, the + projections are often written with a dot '.'). + *) +and extract_trait_instance_id_with_dot (ctx : extraction_ctx) + (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (inside : bool) + (id : trait_instance_id) : unit = + match id with + | Self -> + (* There are two situations: + - we are extracting a declared item and need to refer to another + item (for instance, we are extracting a method signature and + need to refer to an associated type). + We directly refer to the other item (we extract trait declarations + as structures, so we can refer to their fields) + - we are extracting a provided method for a trait declaration. We + refer to the item in the self trait clause (see {!SelfTraitClauseId}). + + Remark: we can't get there for trait *implementations* because then the + types should have been normalized. + *) + if ctx.is_provided_method then + (* Provided method: use the trait self clause *) + let self_clause = ctx_get_trait_self_clause ctx in + F.pp_print_string fmt (self_clause ^ ".") + else + (* Declaration: nothing to print, we will directly refer to + the item. *) + () + | _ -> + (* Other cases *) + extract_trait_instance_id ctx fmt no_params_tys inside id; + F.pp_print_string fmt "." + +and extract_trait_instance_id (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (inside : bool) (id : trait_instance_id) + : unit = + let add_brackets (s : string) = if !backend = Coq then "(" ^ s ^ ")" else s in + match id with + | Self -> + (* This has a specific treatment depending on the item we're extracting + (associated type, etc.). We should have caught this elsewhere. *) + if !Config.fail_hard then + raise (Failure "Unexpected occurrence of `Self`") + else F.pp_print_string fmt "ERROR(\"Unexpected Self\")" + | TraitImpl id -> + let name = ctx_get_trait_impl id ctx in + F.pp_print_string fmt name + | Clause id -> + let name = ctx_get_local_trait_clause id ctx in + F.pp_print_string fmt name + | ParentClause (inst_id, decl_id, clause_id) -> + (* Use the trait decl id to lookup the name *) + let name = ctx_get_trait_parent_clause decl_id clause_id ctx in + extract_trait_instance_id_with_dot ctx fmt no_params_tys true inst_id; + F.pp_print_string fmt (add_brackets name) + | ItemClause (inst_id, decl_id, item_name, clause_id) -> + (* Use the trait decl id to lookup the name *) + let name = ctx_get_trait_item_clause decl_id item_name clause_id ctx in + extract_trait_instance_id_with_dot ctx fmt no_params_tys true inst_id; + F.pp_print_string fmt (add_brackets name) + | TraitRef trait_ref -> + extract_trait_ref ctx fmt no_params_tys inside trait_ref + | UnknownTrait _ -> + (* This is an error case *) + raise (Failure "Unexpected") + +(** Compute the names for all the top-level identifiers used in a type + definition (type name, variant names, field names, etc. but not type + parameters). + + We need to do this preemptively, beforce extracting any definition, + because of recursive definitions. + *) +let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : + extraction_ctx = + (* Lookup the builtin information, if there is *) + let open ExtractBuiltin in + let sname = name_to_simple_name def.name in + let info = SimpleNameMap.find_opt sname (builtin_types_map ()) in + (* Register the filtering information, if there is *) + let ctx = + match info with + | Some { keep_params = Some keep; _ } -> + { + ctx with + types_filter_type_args_map = + TypeDeclId.Map.add def.def_id keep ctx.types_filter_type_args_map; + } + | _ -> ctx + in + (* Compute and register the type def name *) + let def_name = + match info with + | None -> ctx.fmt.type_name def.name + | Some info -> info.extract_name + in + let ctx = ctx_add (TypeId (AdtId def.def_id)) def_name ctx in + (* Compute and register: + * - the variant names, if this is an enumeration + * - the field names, if this is a structure + *) + let ctx = + match def.kind with + | Struct fields -> + (* Compute the names *) + let field_names, cons_name = + match info with + | None | Some { body_info = None; _ } -> + let field_names = + FieldId.mapi + (fun fid (field : field) -> + (fid, ctx.fmt.field_name def.name fid field.field_name)) + fields + in + let cons_name = ctx.fmt.struct_constructor def.name in + (field_names, cons_name) + | Some { body_info = Some (Struct (cons_name, field_names)); _ } -> + let field_names = + FieldId.mapi + (fun fid (field : field) -> + let rust_name = Option.get field.field_name in + let name = + snd (List.find (fun (n, _) -> n = rust_name) field_names) + in + (fid, name)) + fields + in + (field_names, cons_name) + | Some info -> + raise + (Failure + ("Invalid builtin information: " + ^ show_builtin_type_info info)) + in + (* Add the fields *) + let ctx = + List.fold_left + (fun ctx (fid, name) -> + ctx_add (FieldId (AdtId def.def_id, fid)) name ctx) + ctx field_names + in + (* Add the constructor name *) + ctx_add (StructId (AdtId def.def_id)) cons_name ctx + | Enum variants -> + let variant_names = + match info with + | None -> + VariantId.mapi + (fun variant_id (variant : variant) -> + let name = + ctx.fmt.variant_name def.name variant.variant_name + in + (* Add the type name prefix for Lean *) + let name = + if !Config.backend = Lean then + let type_name = ctx.fmt.type_name def.name in + type_name ^ "." ^ name + else name + in + (variant_id, name)) + variants + | Some { body_info = Some (Enum variant_infos); _ } -> + (* We need to compute the map from variant to variant *) + let variant_map = + StringMap.of_list + (List.map + (fun (info : builtin_enum_variant_info) -> + (info.rust_variant_name, info.extract_variant_name)) + variant_infos) + in + VariantId.mapi + (fun variant_id (variant : variant) -> + (variant_id, StringMap.find variant.variant_name variant_map)) + variants + | _ -> raise (Failure "Invalid builtin information") + in + List.fold_left + (fun ctx (vid, vname) -> + ctx_add (VariantId (AdtId def.def_id, vid)) vname ctx) + ctx variant_names + | Opaque -> + (* Nothing to do *) + ctx + in + (* Return *) + ctx + +(** Print the variants *) +let extract_type_decl_variant (ctx : extraction_ctx) (fmt : F.formatter) + (type_decl_group : TypeDeclId.Set.t) (type_name : string) + (type_params : string list) (cg_params : string list) (cons_name : string) + (fields : field list) : unit = + F.pp_print_space fmt (); + (* variant box *) + F.pp_open_hvbox fmt ctx.indent_incr; + (* [| Cons :] + * Note that we really don't want any break above so we print everything + * at once. *) + let opt_colon = if !backend <> HOL4 then " :" else "" in + F.pp_print_string fmt ("| " ^ cons_name ^ opt_colon); + let print_field (fid : FieldId.id) (f : field) (ctx : extraction_ctx) : + extraction_ctx = + F.pp_print_space fmt (); + (* Open the field box *) + F.pp_open_box fmt ctx.indent_incr; + (* Print the field names, if the backend accepts it. + * [ x :] + * Note that when printing fields, we register the field names as + * *variables*: they don't need to be unique at the top level. *) + let ctx = + match !backend with + | FStar -> ( + match f.field_name with + | None -> ctx + | Some field_name -> + let var_id = VarId.of_int (FieldId.to_int fid) in + let field_name = + ctx.fmt.var_basename ctx.names_maps.names_map.names_set + (Some field_name) f.field_ty + in + let ctx, field_name = ctx_add_var field_name var_id ctx in + F.pp_print_string fmt (field_name ^ " :"); + F.pp_print_space fmt (); + ctx) + | Coq | Lean | HOL4 -> ctx + in + (* Print the field type *) + let inside = !backend = HOL4 in + extract_ty ctx fmt type_decl_group inside f.field_ty; + (* Print the arrow [->] *) + if !backend <> HOL4 then ( + F.pp_print_space fmt (); + extract_arrow fmt ()); + (* Close the field box *) + F.pp_close_box fmt (); + (* Return *) + ctx + in + (* Print the fields *) + let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in + let _ = + List.fold_left (fun ctx (fid, f) -> print_field fid f ctx) ctx fields + in + (* Sanity check: HOL4 doesn't support const generics *) + assert (cg_params = [] || !backend <> HOL4); + (* Print the final type *) + if !backend <> HOL4 then ( + F.pp_print_space fmt (); + F.pp_open_hovbox fmt 0; + F.pp_print_string fmt type_name; + List.iter + (fun p -> + F.pp_print_space fmt (); + F.pp_print_string fmt p) + (List.append type_params cg_params); + F.pp_close_box fmt ()); + (* Close the variant box *) + F.pp_close_box fmt () + +(* TODO: we don' need the [def_name] paramter: it can be retrieved from the context *) +let extract_type_decl_enum_body (ctx : extraction_ctx) (fmt : F.formatter) + (type_decl_group : TypeDeclId.Set.t) (def : type_decl) (def_name : string) + (type_params : string list) (cg_params : string list) + (variants : variant list) : unit = + (* We want to generate a definition which looks like this (taking F* as example): + {[ + type list a = | Cons : a -> list a -> list a | Nil : list a + ]} + + If there isn't enough space on one line: + {[ + type s = + | Cons : a -> list a -> list a + | Nil : list a + ]} + + And if we need to write the type of a variant on several lines: + {[ + type s = + | Cons : + a -> + list a -> + list a + | Nil : list a + ]} + + Finally, it is possible to give names to the variant fields in Rust. + In this situation, we generate a definition like this: + {[ + type s = + | Cons : hd:a -> tl:list a -> list a + | Nil : list a + ]} + + Note that we already printed: [type s =] + *) + let print_variant _variant_id (v : variant) = + (* We don't lookup the name, because it may have a prefix for the type + id (in the case of Lean) *) + let cons_name = ctx.fmt.variant_name def.name v.variant_name in + let fields = v.fields in + extract_type_decl_variant ctx fmt type_decl_group def_name type_params + cg_params cons_name fields + in + (* Print the variants *) + let variants = VariantId.mapi (fun vid v -> (vid, v)) variants in + List.iter (fun (vid, v) -> print_variant vid v) variants + +let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) + (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) + (type_params : string list) (cg_params : string list) (fields : field list) + : unit = + (* We want to generate a definition which looks like this (taking F* as example): + {[ + type t = { x : int; y : bool; } + ]} + + If there isn't enough space on one line: + {[ + type t = + { + x : int; y : bool; + } + ]} + + And if there is even less space: + {[ + type t = + { + x : int; + y : bool; + } + ]} + + Also, in case there are no fields, we need to define the type as [unit] + ([type t = {}] doesn't work in F* ). + + Coq: + ==== + We need to define the constructor name upon defining the struct (record, in Coq). + The syntex is: + {[ + Record Foo = mkFoo { x : int; y : bool; }. + }] + + Also, Coq doesn't support groups of mutually recursive inductives and records. + This is fine, because we can then define records as inductives, and leverage + the fact that when record fields are accessed, the records are symbolically + expanded which introduces let bindings of the form: [let RecordCons ... = x in ...]. + As a consequence, we never use the record projectors (unless we reconstruct + them in the micro passes of course). + + HOL4: + ===== + Type definitions are written as follows: + {[ + Datatype: + tree = + TLeaf 'a + | TNode node ; + + node = + Node (tree list) + End + ]} + *) + (* Note that we already printed: [type t =] *) + let is_rec = decl_is_from_rec_group kind in + let _ = + if !backend = FStar && fields = [] then ( + F.pp_print_space fmt (); + F.pp_print_string fmt (unit_name ())) + else if !backend = Lean && fields = [] then () + (* If the definition is recursive, we may need to extract it as an inductive + (instead of a record). We start with the "normal" case: we extract it + as a record. *) + else if (not is_rec) || (!backend <> Coq && !backend <> Lean) then ( + if !backend <> Lean then F.pp_print_space fmt (); + (* If Coq: print the constructor name *) + (* TODO: remove superfluous test not is_rec below *) + if !backend = Coq && not is_rec then ( + F.pp_print_string fmt (ctx_get_struct (AdtId def.def_id) ctx); + F.pp_print_string fmt " "); + (match !backend with + | Lean -> () + | FStar | Coq -> F.pp_print_string fmt "{" + | HOL4 -> F.pp_print_string fmt "<|"); + F.pp_print_break fmt 1 ctx.indent_incr; + (* The body itself *) + (* Open a box for the body *) + (match !backend with + | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0 + | Lean -> F.pp_open_vbox fmt 0); + (* Print the fields *) + let print_field (field_id : FieldId.id) (f : field) : unit = + let field_name = ctx_get_field (AdtId def.def_id) field_id ctx in + (* Open a box for the field *) + F.pp_open_box fmt ctx.indent_incr; + F.pp_print_string fmt field_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt type_decl_group false f.field_ty; + if !backend <> Lean then F.pp_print_string fmt ";"; + (* Close the box for the field *) + F.pp_close_box fmt () + in + let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in + Collections.List.iter_link (F.pp_print_space fmt) + (fun (fid, f) -> print_field fid f) + fields; + (* Close the box for the body *) + F.pp_close_box fmt (); + match !backend with + | Lean -> () + | FStar | Coq -> + F.pp_print_space fmt (); + F.pp_print_string fmt "}" + | HOL4 -> + F.pp_print_space fmt (); + F.pp_print_string fmt "|>") + else ( + (* We extract for Coq or Lean, and we have a recursive record, or a record in + a group of mutually recursive types: we extract it as an inductive type *) + assert (is_rec && (!backend = Coq || !backend = Lean)); + (* Small trick: in Lean we use namespaces, meaning we don't need to prefix + the constructor name with the name of the type at definition site, + i.e., instead of generating `inductive Foo := | MkFoo ...` like in Coq + we generate `inductive Foo := | mk ... *) + let cons_name = + if !backend = Lean then "mk" else ctx_get_struct (AdtId def.def_id) ctx + in + let def_name = ctx_get_local_type def.def_id ctx in + extract_type_decl_variant ctx fmt type_decl_group def_name type_params + cg_params cons_name fields) + in + () + +(** Extract a nestable, muti-line comment *) +let extract_comment (fmt : F.formatter) (sl : string list) : unit = + (* Delimiters, space after we break a line *) + let ld, space, rd = + match !backend with + | Coq | FStar | HOL4 -> ("(** ", 4, " *)") + | Lean -> ("/- ", 3, " -/") + in + F.pp_open_vbox fmt space; + F.pp_print_string fmt ld; + (match sl with + | [] -> () + | s :: sl -> + F.pp_print_string fmt s; + List.iter + (fun s -> + F.pp_print_space fmt (); + F.pp_print_string fmt s) + sl); + F.pp_print_string fmt rd; + F.pp_close_box fmt () + +let extract_trait_clause_type (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) (clause : trait_clause) : unit = + let trait_name = ctx_get_trait_decl clause.trait_id ctx in + F.pp_print_string fmt trait_name; + extract_generic_args ctx fmt no_params_tys clause.generics + +(** Insert a space, if necessary *) +let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = + if !space then space := false else F.pp_print_space fmt () + +(** Extract the trait self clause. + + We add the trait self clause for provided methods (see {!TraitSelfClauseId}). + *) +let extract_trait_self_clause (insert_req_space : unit -> unit) + (ctx : extraction_ctx) (fmt : F.formatter) (trait_decl : trait_decl) + (params : string list) : unit = + insert_req_space (); + F.pp_print_string fmt "("; + let self_clause = ctx_get_trait_self_clause ctx in + F.pp_print_string fmt self_clause; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + let trait_id = ctx_get_trait_decl trait_decl.def_id ctx in + F.pp_print_string fmt trait_id; + List.iter + (fun p -> + F.pp_print_space fmt (); + F.pp_print_string fmt p) + params; + F.pp_print_string fmt ")" + +(** + - [trait_decl]: if [Some], it means we are extracting the generics for a provided + method and need to insert a trait self clause (see {!TraitSelfClauseId}). + *) +let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter) + (no_params_tys : TypeDeclId.Set.t) ?(use_forall = false) + ?(use_forall_use_sep = true) ?(use_arrows = false) + ?(as_implicits : bool = false) ?(space : bool ref option = None) + ?(trait_decl : trait_decl option = None) (generics : generic_params) + (type_params : string list) (cg_params : string list) + (trait_clauses : string list) : unit = + let all_params = List.concat [ type_params; cg_params; trait_clauses ] in + (* HOL4 doesn't support const generics *) + assert (cg_params = [] || !backend <> HOL4); + let left_bracket (implicit : bool) = + if implicit && !backend <> FStar then F.pp_print_string fmt "{" + else F.pp_print_string fmt "(" + in + let right_bracket (implicit : bool) = + if implicit && !backend <> FStar then F.pp_print_string fmt "}" + else F.pp_print_string fmt ")" + in + let print_implicit_symbol (implicit : bool) = + if implicit && !backend = FStar then F.pp_print_string fmt "#" else () + in + let insert_req_space () = + match space with + | None -> F.pp_print_space fmt () + | Some space -> insert_req_space fmt space + in + (* Print the type/const generic parameters *) + if all_params <> [] then ( + if use_forall then ( + if use_forall_use_sep then ( + insert_req_space (); + F.pp_print_string fmt ":"); + insert_req_space (); + F.pp_print_string fmt "forall"); + (* Small helper - we may need to split the parameters *) + let print_generics (as_implicits : bool) (type_params : string list) + (const_generics : const_generic_var list) + (trait_clauses : trait_clause list) : unit = + (* Note that in HOL4 we don't print the type parameters. *) + if !backend <> HOL4 then ( + (* Print the type parameters *) + if type_params <> [] then ( + insert_req_space (); + (* ( *) + left_bracket as_implicits; + List.iter + (fun s -> + print_implicit_symbol as_implicits; + F.pp_print_string fmt s; + F.pp_print_space fmt ()) + type_params; + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt (type_keyword ()); + (* ) *) + right_bracket as_implicits; + if use_arrows then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "->")); + (* Print the const generic parameters *) + List.iter + (fun (var : const_generic_var) -> + insert_req_space (); + (* ( *) + left_bracket as_implicits; + let n = ctx_get_const_generic_var var.index ctx in + print_implicit_symbol as_implicits; + F.pp_print_string fmt n; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_literal_type ctx fmt var.ty; + (* ) *) + right_bracket as_implicits; + if use_arrows then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "->")) + const_generics); + (* Print the trait clauses *) + List.iter + (fun (clause : trait_clause) -> + insert_req_space (); + (* ( *) + left_bracket as_implicits; + let n = ctx_get_local_trait_clause clause.clause_id ctx in + print_implicit_symbol as_implicits; + F.pp_print_string fmt n; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_trait_clause_type ctx fmt no_params_tys clause; + (* ) *) + right_bracket as_implicits; + if use_arrows then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "->")) + trait_clauses + in + (* If we extract the generics for a provided method for a trait declaration + (indicated by the trait decl given as input), we need to split the generics: + - we print the generics for the trait decl + - we print the trait self clause + - we print the generics for the trait method + *) + match trait_decl with + | None -> + print_generics as_implicits type_params generics.const_generics + generics.trait_clauses + | Some trait_decl -> + (* Split the generics between the generics specific to the trait decl + and those specific to the trait method *) + let open Collections.List in + let dtype_params, mtype_params = + split_at type_params (length trait_decl.generics.types) + in + let dcgs, mcgs = + split_at generics.const_generics + (length trait_decl.generics.const_generics) + in + let dtrait_clauses, mtrait_clauses = + split_at generics.trait_clauses + (length trait_decl.generics.trait_clauses) + in + (* Extract the trait decl generics - note that we can always deduce + those parameters from the trait self clause: for this reason + they are always implicit *) + print_generics true dtype_params dcgs dtrait_clauses; + (* Extract the trait self clause *) + let params = + concat + [ + dtype_params; + map + (fun (cg : const_generic_var) -> + ctx_get_const_generic_var cg.index ctx) + dcgs; + map + (fun c -> ctx_get_local_trait_clause c.clause_id ctx) + dtrait_clauses; + ] + in + extract_trait_self_clause insert_req_space ctx fmt trait_decl params; + (* Extract the method generics *) + print_generics as_implicits mtype_params mcgs mtrait_clauses) + +(** Extract a type declaration. + + This function is for all type declarations and all backends **at the exception** + of opaque (assumed/declared) types format4 HOL4. + + See {!extract_type_decl}. + *) +let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) + (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) + (extract_body : bool) : unit = + (* Sanity check *) + assert (extract_body || !backend <> HOL4); + let type_kind = + if extract_body then + match def.kind with + | Struct _ -> Some Struct + | Enum _ -> Some Enum + | Opaque -> None + else None + in + (* If in Coq and the declaration is opaque, it must have the shape: + [Axiom Ident : forall (T0 ... Tn : Type) (N0 : ...) ... (Nn : ...), ... -> ... -> ...]. + + The boolean [is_opaque_coq] is used to detect this case. + *) + let is_opaque = type_kind = None in + let is_opaque_coq = !backend = Coq && is_opaque in + let use_forall = is_opaque_coq && def.generics <> empty_generic_params in + (* Retrieve the definition name *) + let def_name = ctx_get_local_type def.def_id ctx in + (* Add the type and const generic params - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx_body, type_params, cg_params, trait_clauses = + ctx_add_generic_params def.generics ctx + in + (* Add a break before *) + if !backend <> HOL4 || not (decl_is_first_from_group kind) then + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + extract_comment fmt [ "[" ^ Print.name_to_string def.name ^ "]" ]; + F.pp_print_break fmt 0 0; + (* Open a box for the definition, so that whenever possible it gets printed on + * one line. Note however that in the case of Lean line breaks are important + * for parsing: we thus use a hovbox. *) + (match !backend with + | Coq | FStar | HOL4 -> F.pp_open_hvbox fmt 0 + | Lean -> F.pp_open_vbox fmt 0); + (* Open a box for "type TYPE_NAME (TYPE_PARAMS CONST_GEN_PARAMS) =" *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* > "type TYPE_NAME" *) + let qualif = ctx.fmt.type_decl_kind_to_qualif kind type_kind in + (match qualif with + | Some qualif -> F.pp_print_string fmt (qualif ^ " " ^ def_name) + | None -> F.pp_print_string fmt def_name); + (* HOL4 doesn't support const generics, and type definitions in HOL4 don't + support trait clauses *) + assert ((cg_params = [] && trait_clauses = []) || !backend <> HOL4); + (* Print the generic parameters *) + extract_generic_params ctx_body fmt type_decl_group ~use_forall def.generics + type_params cg_params trait_clauses; + (* Print the "=" if we extract the body*) + if extract_body then ( + F.pp_print_space fmt (); + let eq = + match !backend with + | FStar -> "=" + | Coq -> ":=" + | Lean -> + if type_kind = Some Struct && kind = SingleNonRec then "where" + else ":=" + | HOL4 -> "=" + in + F.pp_print_string fmt eq) + else ( + (* Otherwise print ": Type", unless it is the HOL4 backend (in + which case we declare the type with `new_type`) *) + if use_forall then F.pp_print_string fmt "," + else ( + F.pp_print_space fmt (); + F.pp_print_string fmt ":"); + F.pp_print_space fmt (); + F.pp_print_string fmt (type_keyword ())); + (* Close the box for "type TYPE_NAME (TYPE_PARAMS) =" *) + F.pp_close_box fmt (); + (if extract_body then + match def.kind with + | Struct fields -> + extract_type_decl_struct_body ctx_body fmt type_decl_group kind def + type_params cg_params fields + | Enum variants -> + extract_type_decl_enum_body ctx_body fmt type_decl_group def def_name + type_params cg_params variants + | Opaque -> raise (Failure "Unreachable")); + (* Add the definition end delimiter *) + if !backend = HOL4 && decl_is_not_last_from_group kind then ( + F.pp_print_space fmt (); + F.pp_print_string fmt ";") + else if !backend = Coq && decl_is_last_from_group kind then ( + (* This is actually an end of group delimiter. For aesthetic reasons + we print it here instead of in {!end_type_decl_group}. *) + F.pp_print_cut fmt (); + F.pp_print_string fmt "."); + (* Close the box for the definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + if !backend <> HOL4 || decl_is_not_last_from_group kind then + F.pp_print_break fmt 0 0 + +(** Extract an opaque type declaration to HOL4. + + Remark (SH): having to treat this specific case separately is very annoying, + but I could not find a better way. + *) +let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) + (def : type_decl) : unit = + (* Retrieve the definition name *) + let def_name = ctx_get_local_type def.def_id ctx in + (* Generic parameters are unsupported *) + assert (def.generics.const_generics = []); + (* Trait clauses on type definitions are unsupported *) + assert (def.generics.trait_clauses = []); + (* Types *) + (* Count the number of parameters *) + let num_params = List.length def.generics.types in + (* Generate the declaration *) + F.pp_print_space fmt (); + F.pp_print_string fmt + ("val _ = new_type (\"" ^ def_name ^ "\", " ^ string_of_int num_params ^ ")"); + F.pp_print_space fmt () + +(** Extract an empty record type declaration to HOL4. + + Empty records are not supported in HOL4, so we extract them as type + abbreviations to the unit type. + + Remark (SH): having to treat this specific case separately is very annoying, + but I could not find a better way. + *) +let extract_type_decl_hol4_empty_record (ctx : extraction_ctx) + (fmt : F.formatter) (def : type_decl) : unit = + (* Retrieve the definition name *) + let def_name = ctx_get_local_type def.def_id ctx in + (* Sanity check *) + assert (def.generics = empty_generic_params); + (* Generate the declaration *) + F.pp_print_space fmt (); + F.pp_print_string fmt ("Type " ^ def_name ^ " = “: unit”"); + F.pp_print_space fmt () + +(** Extract a type declaration. + + Note that all the names used for extraction should already have been + registered. + + This function should be inserted between calls to {!start_type_decl_group} + and {!end_type_decl_group}. + *) +let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter) + (type_decl_group : TypeDeclId.Set.t) (kind : decl_kind) (def : type_decl) : + unit = + let extract_body = + match kind with + | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> true + | Assumed | Declared -> false + in + if extract_body then + if !backend = HOL4 && is_empty_record_type_decl def then + extract_type_decl_hol4_empty_record ctx fmt def + else extract_type_decl_gen ctx fmt type_decl_group kind def extract_body + else + match !backend with + | FStar | Coq | Lean -> + extract_type_decl_gen ctx fmt type_decl_group kind def extract_body + | HOL4 -> extract_type_decl_hol4_opaque ctx fmt def + +(** Generate a [Argument] instruction in Coq to allow omitting implicit + arguments for variants, fields, etc.. + + For instance, provided we have this definition: + {[ + Inductive result A := + | Return : A -> result A + | Fail_ : error -> result A. + ]} + + We may want to generate those instructions: + {[ + Arguments Return {_} a. + Arguments Fail_ {_}. + ]} + *) +let extract_coq_arguments_instruction (ctx : extraction_ctx) (fmt : F.formatter) + (cons_name : string) (num_implicit_params : int) : unit = + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Open a box *) + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_break fmt 0 0; + F.pp_print_string fmt "Arguments"; + F.pp_print_space fmt (); + F.pp_print_string fmt cons_name; + (* Print the type/const params and the trait clauses (`{T}`) *) + F.pp_print_space fmt (); + F.pp_print_string fmt "{"; + Collections.List.iter_times num_implicit_params (fun () -> + F.pp_print_space fmt (); + F.pp_print_string fmt "_"); + F.pp_print_space fmt (); + F.pp_print_string fmt "}."; + + (* Close the box *) + F.pp_close_box fmt () + +(** Auxiliary function. + + Generate [Arguments] instructions in Coq for type definitions. + *) +let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) + (kind : decl_kind) (decl : type_decl) : unit = + assert (!backend = Coq); + (* Generating the [Arguments] instructions is useful only if there are parameters *) + let num_params = + List.length decl.generics.types + + List.length decl.generics.const_generics + + List.length decl.generics.trait_clauses + in + if num_params = 0 then () + else + (* Generate the [Arguments] instruction *) + match decl.kind with + | Opaque -> () + | Struct fields -> + let adt_id = AdtId decl.def_id in + (* Generate the instruction for the record constructor *) + let cons_name = ctx_get_struct adt_id ctx in + extract_coq_arguments_instruction ctx fmt cons_name num_params; + (* Generate the instruction for the record projectors, if there are *) + let is_rec = decl_is_from_rec_group kind in + if not is_rec then + FieldId.iteri + (fun fid _ -> + let cons_name = ctx_get_field adt_id fid ctx in + extract_coq_arguments_instruction ctx fmt cons_name num_params) + fields; + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + | Enum variants -> + (* Generate the instructions *) + VariantId.iteri + (fun vid (_ : variant) -> + let cons_name = ctx_get_variant (AdtId decl.def_id) vid ctx in + extract_coq_arguments_instruction ctx fmt cons_name num_params) + variants; + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + +(** Auxiliary function. + + Generate field projectors in Coq. + + Sometimes we extract records as inductives in Coq: when this happens we + have to define the field projectors afterwards. + *) +let extract_type_decl_record_field_projectors (ctx : extraction_ctx) + (fmt : F.formatter) (kind : decl_kind) (decl : type_decl) : unit = + assert (!backend = Coq); + match decl.kind with + | Opaque | Enum _ -> () + | Struct fields -> + (* Records are extracted as inductives only if they are recursive *) + let is_rec = decl_is_from_rec_group kind in + if is_rec then + (* Add the type params *) + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params decl.generics ctx + in + let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in + let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in + let def_name = ctx_get_local_type decl.def_id ctx in + let cons_name = ctx_get_struct (AdtId decl.def_id) ctx in + let extract_field_proj (field_id : FieldId.id) (_ : field) : unit = + F.pp_print_space fmt (); + (* Outer box for the projector definition *) + F.pp_open_hvbox fmt 0; + (* Inner box for the projector definition *) + F.pp_open_hvbox fmt ctx.indent_incr; + (* Open a box for the [Definition PROJ ... :=] *) + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_string fmt "Definition"; + F.pp_print_space fmt (); + let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in + F.pp_print_string fmt field_name; + (* Print the generics *) + let as_implicits = true in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~as_implicits + decl.generics type_params cg_params trait_clauses; + (* Print the record parameter *) + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + F.pp_print_string fmt record_var; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt def_name; + List.iter + (fun p -> + F.pp_print_space fmt (); + F.pp_print_string fmt p) + type_params; + F.pp_print_string fmt ")"; + (* *) + F.pp_print_space fmt (); + F.pp_print_string fmt ":="; + (* Close the box for the [Definition PROJ ... :=] *) + F.pp_close_box fmt (); + F.pp_print_space fmt (); + (* Open a box for the whole match *) + F.pp_open_hvbox fmt 0; + (* Open a box for the [match ... with] *) + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_string fmt "match"; + F.pp_print_space fmt (); + F.pp_print_string fmt record_var; + F.pp_print_space fmt (); + F.pp_print_string fmt "with"; + (* Close the box for the [match ... with] *) + F.pp_close_box fmt (); + + (* Open a box for the branch *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the match branch *) + F.pp_print_space fmt (); + F.pp_print_string fmt "|"; + F.pp_print_space fmt (); + F.pp_print_string fmt cons_name; + FieldId.iteri + (fun id _ -> + F.pp_print_space fmt (); + if field_id = id then F.pp_print_string fmt field_var + else F.pp_print_string fmt "_") + fields; + F.pp_print_space fmt (); + F.pp_print_string fmt "=>"; + F.pp_print_space fmt (); + F.pp_print_string fmt field_var; + (* Close the box for the branch *) + F.pp_close_box fmt (); + (* Print the [end] *) + F.pp_print_space fmt (); + F.pp_print_string fmt "end"; + (* Close the box for the whole match *) + F.pp_close_box fmt (); + (* Close the inner box projector *) + F.pp_close_box fmt (); + (* If Coq: end the definition with a "." *) + if !backend = Coq then ( + F.pp_print_cut fmt (); + F.pp_print_string fmt "."); + (* Close the outer box projector *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + in + + let extract_proj_notation (field_id : FieldId.id) (_ : field) : unit = + F.pp_print_space fmt (); + (* Outer box for the projector definition *) + F.pp_open_hvbox fmt 0; + (* Inner box for the projector definition *) + F.pp_open_hovbox fmt ctx.indent_incr; + let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in + F.pp_print_string fmt "Notation"; + F.pp_print_space fmt (); + let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in + F.pp_print_string fmt ("\"" ^ record_var ^ " .(" ^ field_name ^ ")\""); + F.pp_print_space fmt (); + F.pp_print_string fmt ":="; + F.pp_print_space fmt (); + F.pp_print_string fmt "("; + F.pp_print_string fmt field_name; + F.pp_print_space fmt (); + F.pp_print_string fmt record_var; + F.pp_print_string fmt ")"; + F.pp_print_space fmt (); + F.pp_print_string fmt "(at level 9)"; + (* Close the inner box projector *) + F.pp_close_box fmt (); + (* If Coq: end the definition with a "." *) + if !backend = Coq then ( + F.pp_print_cut fmt (); + F.pp_print_string fmt "."); + (* Close the outer box projector *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 + in + + let extract_field_proj_and_notation (field_id : FieldId.id) + (field : field) : unit = + extract_field_proj field_id field; + extract_proj_notation field_id field + in + + FieldId.iteri extract_field_proj_and_notation fields + +(** Extract extra information for a type (e.g., [Arguments] instructions in Coq). + + Note that all the names used for extraction should already have been + registered. + *) +let extract_type_decl_extra_info (ctx : extraction_ctx) (fmt : F.formatter) + (kind : decl_kind) (decl : type_decl) : unit = + match !backend with + | FStar | Lean | HOL4 -> () + | Coq -> + extract_type_decl_coq_arguments ctx fmt kind decl; + extract_type_decl_record_field_projectors ctx fmt kind decl + +(** Extract the state type declaration. *) +let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) + (kind : decl_kind) : unit = + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment *) + extract_comment fmt [ "The state type used in the state-error monad" ]; + F.pp_print_break fmt 0 0; + (* Open a box for the definition, so that whenever possible it gets printed on + * one line *) + F.pp_open_hvbox fmt 0; + (* Retrieve the name *) + let state_name = ctx_get_assumed_type State ctx in + (* The syntax for Lean and Coq is almost identical. *) + let print_axiom () = + let axiom = + match !backend with + | Coq -> "Axiom" + | Lean -> "axiom" + | FStar | HOL4 -> raise (Failure "Unexpected") + in + F.pp_print_string fmt axiom; + F.pp_print_space fmt (); + F.pp_print_string fmt state_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "Type"; + if !backend = Coq then F.pp_print_string fmt "." + in + (* The kind should be [Assumed] or [Declared] *) + (match kind with + | SingleNonRec | SingleRec | MutRecFirst | MutRecInner | MutRecLast -> + raise (Failure "Unexpected") + | Assumed -> ( + match !backend with + | FStar -> + F.pp_print_string fmt "assume"; + F.pp_print_space fmt (); + F.pp_print_string fmt "type"; + F.pp_print_space fmt (); + F.pp_print_string fmt state_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "Type0" + | HOL4 -> + F.pp_print_string fmt ("val _ = new_type (\"" ^ state_name ^ "\", 0)") + | Coq | Lean -> print_axiom ()) + | Declared -> ( + match !backend with + | FStar -> + F.pp_print_string fmt "val"; + F.pp_print_space fmt (); + F.pp_print_string fmt state_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt "Type0" + | HOL4 -> + F.pp_print_string fmt ("val _ = new_type (\"" ^ state_name ^ "\", 0)") + | Coq | Lean -> print_axiom ())); + (* Close the box for the definition *) + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml index b72fa078..e17ea16f 100644 --- a/compiler/FunsAnalysis.ml +++ b/compiler/FunsAnalysis.ml @@ -57,12 +57,26 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) let stateful = ref false in let can_diverge = ref false in let is_rec = ref false in + let group_has_builtin_info = ref false in + + (* We have some specialized knowledge of some library functions; we don't + have any more custom treatment than this, and these functions can be modeled + suitably in Primitives.fst, rather than special-casing for them all the + way. *) + let get_builtin_info (f : fun_decl) : ExtractBuiltin.effect_info option = + let open ExtractBuiltin in + let name = name_to_simple_name f.name in + SimpleNameMap.find_opt name builtin_fun_effects_map + in + (* JP: Why not use a reduce visitor here with a tuple of the values to be + computed? *) let visit_fun (f : fun_decl) : unit = let obj = object (self) inherit [_] iter_statement as super method may_fail b = can_fail := !can_fail || b + method maybe_stateful b = stateful := !stateful || b method! visit_Assert env a = self#may_fail true; @@ -70,14 +84,14 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) method! visit_rvalue _env rv = match rv with - | Use _ | Ref _ | Global _ | Discriminant _ | Aggregate _ -> () + | Use _ | RvRef _ | Global _ | Discriminant _ | Aggregate _ -> () | UnaryOp (uop, _) -> can_fail := EU.unop_can_fail uop || !can_fail | BinaryOp (bop, _, _) -> can_fail := EU.binop_can_fail bop || !can_fail method! visit_Call env call = - (match call.func with - | Regular id -> + (match call.func.func with + | FunId (Regular id) -> if FunDeclId.Set.mem id fun_ids then ( can_diverge := true; is_rec := true) @@ -86,9 +100,14 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) self#may_fail info.can_fail; stateful := !stateful || info.stateful; can_diverge := !can_diverge || info.can_diverge - | Assumed id -> + | FunId (Assumed id) -> (* None of the assumed functions can diverge nor are considered stateful *) - can_fail := !can_fail || Assumed.assumed_can_fail id); + can_fail := !can_fail || Assumed.assumed_fun_can_fail id + | TraitMethod _ -> + (* We consider trait functions can fail, but can not diverge and are not stateful. + TODO: this may cause issues if we use use a fuel parameter. + *) + can_fail := true); super#visit_Call env call method! visit_Panic env = @@ -102,11 +121,21 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) in (* Sanity check: global bodies don't contain stateful calls *) assert ((not f.is_global_decl_body) || not !stateful); + let builtin_info = get_builtin_info f in + let has_builtin_info = builtin_info <> None in + group_has_builtin_info := !group_has_builtin_info || has_builtin_info; match f.body with | None -> - (* Opaque function: we consider they fail by default *) - obj#may_fail true; - stateful := (not f.is_global_decl_body) && use_state + let info_can_fail, info_stateful = + match builtin_info with + | None -> (true, use_state) + | Some { can_fail; stateful } -> (can_fail, stateful) + in + obj#may_fail info_can_fail; + obj#maybe_stateful + (if f.is_global_decl_body then false + else if not use_state then false + else info_stateful) | Some body -> obj#visit_statement () body.body in List.iter visit_fun d; @@ -114,12 +143,17 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) * groups containing globals contain exactly one declaration *) let is_global_decl_body = List.exists (fun f -> f.is_global_decl_body) d in assert ((not is_global_decl_body) || List.length d = 1); + assert ((not !group_has_builtin_info) || List.length d = 1); (* We ignore on purpose functions that cannot fail and consider they *can* * fail: the result of the analysis is not used yet to adjust the translation * so that the functions which syntactically can't fail don't use an error monad. - * However, we do keep the result of the analysis for global bodies. + * However, we do keep the result of the analysis for global bodies and for + * builtin functions which are marked as non-fallible. * *) - can_fail := (not is_global_decl_body) || !can_fail; + can_fail := + if is_global_decl_body then !can_fail + else if !group_has_builtin_info then !can_fail + else true; { can_fail = !can_fail; stateful = !stateful; @@ -141,7 +175,8 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) let rec analyze_decl_groups (decls : declaration_group list) : unit = match decls with | [] -> () - | Type _ :: decls' -> analyze_decl_groups decls' + | (Type _ | TraitDecl _ | TraitImpl _) :: decls' -> + analyze_decl_groups decls' | Fun decl :: decls' -> analyze_fun_decl_group decl; analyze_decl_groups decls' diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index 154c5a21..24ff4808 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -12,55 +12,165 @@ module SA = SymbolicAst (** The local logger *) let log = L.interpreter_log -let compute_type_fun_global_contexts (m : A.crate) : - C.type_context * C.fun_context * C.global_context = - let type_decls_list, _, _ = split_declarations m.declarations in +let compute_contexts (m : A.crate) : C.decls_ctx = + let type_decls_list, _, _, _, _ = split_declarations m.declarations in let type_decls = m.types in let fun_decls = m.functions in let global_decls = m.globals in - let type_decls_groups, _funs_defs_groups, _globals_defs_groups = + let trait_decls = m.trait_decls in + let trait_impls = m.trait_impls in + let type_decls_groups, _, _, _, _ = split_declarations_to_group_maps m.declarations in let type_infos = TypesAnalysis.analyze_type_declarations type_decls type_decls_list in - let type_context = { C.type_decls_groups; type_decls; type_infos } in - let fun_context = { C.fun_decls } in - let global_context = { C.global_decls } in - (type_context, fun_context, global_context) - -let initialize_eval_context (type_context : C.type_context) - (fun_context : C.fun_context) (global_context : C.global_context) - (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) - (const_generic_vars : T.const_generic_var list) : C.eval_ctx = - C.reset_global_counters (); - { - C.type_context; - C.fun_context; - C.global_context; - C.region_groups; - C.type_vars; - C.const_generic_vars; - C.env = [ C.Frame ]; - C.ended_regions = T.RegionId.Set.empty; - } + let type_ctx = { C.type_decls_groups; type_decls; type_infos } in + let fun_infos = + FunsAnalysis.analyze_module m fun_decls global_decls !Config.use_state + in + let fun_ctx = { C.fun_decls; fun_infos } in + let global_ctx = { C.global_decls } in + let trait_decls_ctx = { C.trait_decls } in + let trait_impls_ctx = { C.trait_impls } in + { C.type_ctx; fun_ctx; global_ctx; trait_decls_ctx; trait_impls_ctx } + +(** Small helper. + + Normalize an instantiated function signature provided we used this signature + to compute a normalization map (for the associated types) and that we added + it in the context. + *) +let normalize_inst_fun_sig (ctx : C.eval_ctx) (sg : A.inst_fun_sig) : + A.inst_fun_sig = + let { A.regions_hierarchy = _; trait_type_constraints = _; inputs; output } = + sg + in + let norm = AssociatedTypes.ctx_normalize_rty ctx in + let inputs = List.map norm inputs in + let output = norm output in + { sg with A.inputs; output } + +(** Instantiate a function signature for a symbolic execution. + + We return a new context because we compute and add the type normalization + map in the same step. + + **WARNING**: this doesn't normalize the types. This step has to be done + separately. Remark: we need to normalize essentially because of the where + clauses (we are not considering a function call, so we don't need to + normalize because a trait clause was instantiated with a specific trait ref). + *) +let symbolic_instantiate_fun_sig (ctx : C.eval_ctx) (sg : A.fun_sig) + (kind : A.fun_kind) : C.eval_ctx * A.inst_fun_sig = + let tr_self = + match kind with + | RegularKind | TraitMethodImpl _ -> T.UnknownTrait __FUNCTION__ + | TraitMethodDecl _ | TraitMethodProvided _ -> T.Self + in + let generics = + let { T.regions; types; const_generics; trait_clauses } = sg.generics in + let regions = List.map (fun _ -> T.Erased) regions in + let types = List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) types in + let const_generics = + List.map + (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) + const_generics + in + (* Annoying that we have to generate this substitution here *) + let r_subst _ = raise (Failure "Unexpected region") in + let ty_subst = Subst.make_type_subst_from_vars sg.generics.types types in + let cg_subst = + Subst.make_const_generic_subst_from_vars sg.generics.const_generics + const_generics + in + (* TODO: some clauses may use the types of other clauses, so we may have to + reorder them. + + Example: + If in Rust we write: + {[ + pub fn use_get<'a, T: Get>(x: &'a mut T) -> u32 + where + T::Item: ToU32, + { + x.get().to_u32() + } + ]} + + In LLBC we get: + {[ + fn demo::use_get<'a, T>(@1: &'a mut (T)) -> u32 + where + [@TraitClause0]: demo::Get<T>, + [@TraitClause1]: demo::ToU32<@TraitClause0::Item>, // HERE + { + ... // Omitted + } + ]} + *) + (* We will need to update the trait refs map while we perform the instantiations *) + let mk_tr_subst + (tr_map : T.erased_region T.trait_instance_id T.TraitClauseId.Map.t) + clause_id : T.erased_region T.trait_instance_id = + match T.TraitClauseId.Map.find_opt clause_id tr_map with + | Some tr -> tr + | None -> raise (Failure "Local trait clause not found") + in + let mk_subst tr_map = + let tr_subst = mk_tr_subst tr_map in + { Subst.r_subst; ty_subst; cg_subst; tr_subst; tr_self } + in + let _, trait_refs = + List.fold_left_map + (fun tr_map (c : T.trait_clause) -> + let subst = mk_subst tr_map in + let { T.trait_id = trait_decl_id; generics; _ } = c in + let generics = Subst.generic_args_substitute subst generics in + let trait_decl_ref = { T.trait_decl_id; decl_generics = generics } in + (* Note that because we directly refer to the clause, we give it + empty generics *) + let trait_id = T.Clause c.clause_id in + let trait_ref = + { + T.trait_id; + generics = TypesUtils.mk_empty_generic_args; + trait_decl_ref; + } + in + (* Update the traits map *) + let tr_map = T.TraitClauseId.Map.add c.T.clause_id trait_id tr_map in + (tr_map, trait_ref)) + T.TraitClauseId.Map.empty trait_clauses + in + { T.regions; types; const_generics; trait_refs } + in + let inst_sg = instantiate_fun_sig ctx generics tr_self sg in + (* Compute the normalization maps *) + let ctx = + AssociatedTypes.ctx_add_norm_trait_types_from_preds ctx + inst_sg.trait_type_constraints + in + (* Normalize the signature *) + let inst_sg = normalize_inst_fun_sig ctx inst_sg in + (* Return *) + (ctx, inst_sg) (** Initialize an evaluation context to execute a function. - Introduces local variables initialized in the following manner: - - input arguments are initialized as symbolic values - - the remaining locals are initialized as [⊥] - Abstractions are introduced for the regions present in the function - signature. - - We return: - - the initialized evaluation context - - the list of symbolic values introduced for the input values - - the instantiated function signature + Introduces local variables initialized in the following manner: + - input arguments are initialized as symbolic values + - the remaining locals are initialized as [⊥] + Abstractions are introduced for the regions present in the function + signature. + + We return: + - the initialized evaluation context + - the list of symbolic values introduced for the input values + - the instantiated function signature *) -let initialize_symbolic_context_for_fun (type_context : C.type_context) - (fun_context : C.fun_context) (global_context : C.global_context) - (fdef : A.fun_decl) : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig = +let initialize_symbolic_context_for_fun (ctx : C.decls_ctx) (fdef : A.fun_decl) + : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig = (* The abstractions are not initialized the same way as for function * calls: they contain *loan* projectors, because they "provide" us * with the input values (which behave as if they had been returned @@ -78,19 +188,15 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context) List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy in let ctx = - initialize_eval_context type_context fun_context global_context - region_groups sg.type_params sg.const_generic_params + initialize_eval_context ctx region_groups sg.generics.types + sg.generics.const_generics in - (* Instantiate the signature *) - let type_params = - List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params + (* Instantiate the signature. This updates the context because we compute + at the same time the normalization map for the associated types. + *) + let ctx, inst_sg = + symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind in - let cg_params = - List.map - (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) - sg.const_generic_params - in - let inst_sg = instantiate_fun_sig type_params cg_params sg in (* Create fresh symbolic values for the inputs *) let input_svs = List.map (fun ty -> mk_fresh_symbolic_value V.SynthInput ty) inst_sg.inputs @@ -165,15 +271,9 @@ let evaluate_function_symbolic_synthesize_backward_from_return * an instantiation of the signature, so that we use fresh * region ids for the return abstractions. *) let sg = fdef.signature in - let type_params = - List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params - in - let cg_params = - List.map - (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index) - sg.const_generic_params + let _, ret_inst_sg = + symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind in - let ret_inst_sg = instantiate_fun_sig type_params cg_params sg in let ret_rty = ret_inst_sg.output in (* Move the return value out of the return variable *) let pop_return_value = is_regular_return in @@ -347,19 +447,14 @@ let evaluate_function_symbolic_synthesize_backward_from_return for the synthesis) - the symbolic AST generated by the symbolic execution *) -let evaluate_function_symbolic (synthesize : bool) - (type_context : C.type_context) (fun_context : C.fun_context) - (global_context : C.global_context) (fdef : A.fun_decl) : - V.symbolic_value list * SA.expression option = +let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx) + (fdef : A.fun_decl) : V.symbolic_value list * SA.expression option = (* Debug *) let name_to_string () = Print.fun_name_to_string fdef.A.name in log#ldebug (lazy ("evaluate_function_symbolic: " ^ name_to_string ())); (* Create the evaluation context *) - let ctx, input_svs, inst_sg = - initialize_symbolic_context_for_fun type_context fun_context global_context - fdef - in + let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in (* Create the continuation to finish the evaluation *) let config = C.mk_config C.SymbolicMode in @@ -488,7 +583,8 @@ module Test = struct (** Test a unit function (taking no arguments) by evaluating it in an empty environment. *) - let test_unit_function (crate : A.crate) (fid : A.FunDeclId.id) : unit = + let test_unit_function (crate : A.crate) (decls_ctx : C.decls_ctx) + (fid : A.FunDeclId.id) : unit = (* Retrieve the function declaration *) let fdef = A.FunDeclId.Map.find fid crate.functions in let body = Option.get fdef.body in @@ -498,17 +594,11 @@ module Test = struct (lazy ("test_unit_function: " ^ Print.fun_name_to_string fdef.A.name)); (* Sanity check - *) - assert (List.length fdef.A.signature.region_params = 0); - assert (List.length fdef.A.signature.type_params = 0); + assert (fdef.A.signature.generics = TypesUtils.mk_empty_generic_params); assert (body.A.arg_count = 0); (* Create the evaluation context *) - let type_context, fun_context, global_context = - compute_type_fun_global_contexts crate - in - let ctx = - initialize_eval_context type_context fun_context global_context [] [] [] - in + let ctx = initialize_eval_context decls_ctx [] [] [] in (* Insert the (uninitialized) local variables *) let ctx = C.ctx_push_uninitialized_vars ctx body.A.locals in @@ -536,9 +626,7 @@ module Test = struct (no parameters, no arguments) - TODO: move *) let fun_decl_is_transparent_unit (def : A.fun_decl) : bool = Option.is_some def.body - && def.A.signature.region_params = [] - && def.A.signature.type_params = [] - && def.A.signature.const_generic_params = [] + && def.A.signature.generics = TypesUtils.mk_empty_generic_params && def.A.signature.inputs = [] (** Test all the unit functions in a list of function definitions *) @@ -548,24 +636,9 @@ module Test = struct (fun _ -> fun_decl_is_transparent_unit) crate.functions in + let decls_ctx = compute_contexts crate in let test_unit_fun _ (def : A.fun_decl) : unit = - test_unit_function crate def.A.def_id + test_unit_function crate decls_ctx def.A.def_id in A.FunDeclId.Map.iter test_unit_fun unit_funs - - (** Execute the symbolic interpreter on a function. *) - let test_function_symbolic (synthesize : bool) (type_context : C.type_context) - (fun_context : C.fun_context) (global_context : C.global_context) - (fdef : A.fun_decl) : unit = - (* Debug *) - log#ldebug - (lazy ("test_function_symbolic: " ^ Print.fun_name_to_string fdef.A.name)); - - (* Evaluate *) - let _ = - evaluate_function_symbolic synthesize type_context fun_context - global_context fdef - in - - () end diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index 4d67a4e4..e97795a1 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -452,7 +452,8 @@ let give_back_symbolic_value (_config : C.config) | V.SynthInputGivenBack | SynthRetGivenBack | FunCallGivenBack | LoopGivenBack -> () - | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin | Aggregate -> + | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin | Aggregate + | ConstGeneric | TraitConst -> raise (Failure "Unreachable")); (* Store the given-back value as a meta-value for synthesis purposes *) let mv = nsv in diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml index bf083aa4..e7da045c 100644 --- a/compiler/InterpreterBorrowsCore.ml +++ b/compiler/InterpreterBorrowsCore.ml @@ -100,15 +100,18 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) (compare_regions : T.RegionId.id T.region -> T.RegionId.id T.region -> bool) (ty1 : T.rty) (ty2 : T.rty) : bool = let compare = compare_rtys default combine compare_regions in + (* Normalize the associated types *) match (ty1, ty2) with | T.Literal lit1, T.Literal lit2 -> assert (lit1 = lit2); default - | T.Adt (id1, regions1, tys1, cgs1), T.Adt (id2, regions2, tys2, cgs2) -> + | T.Adt (id1, generics1), T.Adt (id2, generics2) -> assert (id1 = id2); (* There are no regions in the const generics, so we ignore them, but we still check they are the same, for sanity *) - assert (cgs1 = cgs2); + assert (generics1.const_generics = generics2.const_generics); + + (* We also ignore the trait refs *) (* The check for the ADTs is very crude: we simply compare the arguments * two by two. @@ -123,14 +126,14 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) * this check would still be a reasonable conservative approximation. *) (* Check the region parameters *) - let regions = List.combine regions1 regions2 in + let regions = List.combine generics1.regions generics2.regions in let params_b = List.fold_left (fun b (r1, r2) -> combine b (compare_regions r1 r2)) default regions in (* Check the type parameters *) - let tys = List.combine tys1 tys2 in + let tys = List.combine generics1.types generics2.types in let tys_b = List.fold_left (fun b (ty1, ty2) -> combine b (compare ty1 ty2)) @@ -150,6 +153,11 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool) | T.TypeVar id1, T.TypeVar id2 -> assert (id1 = id2); default + | T.TraitType _, T.TraitType _ -> + (* The types should have been normalized. If after normalization we + get trait types, we can consider them as variables *) + assert (ty1 = ty2); + default | _ -> log#lerror (lazy diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml index 81e73e3e..b267bb51 100644 --- a/compiler/InterpreterExpansion.ml +++ b/compiler/InterpreterExpansion.ml @@ -9,6 +9,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open TypesUtils module Inv = Invariants @@ -204,7 +205,7 @@ let apply_symbolic_expansion_non_borrow (config : C.config) apply_symbolic_expansion_to_avalues config allow_reborrows original_sv expansion ctx -(** Compute the expansion of a non-assumed (i.e.: not [Option], [Box], etc.) +(** Compute the expansion of a non-assumed (i.e.: not [Box], etc.) adt value. The function might return a list of values if the symbolic value to expand @@ -214,18 +215,15 @@ let apply_symbolic_expansion_non_borrow (config : C.config) doesn't allow the expansion of enumerations *containing several variants*. *) let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) - (kind : V.sv_kind) (def_id : T.TypeDeclId.id) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list - = + (kind : V.sv_kind) (def_id : T.TypeDeclId.id) (generics : T.rgeneric_args) + (ctx : C.eval_ctx) : V.symbolic_expansion list = (* Lookup the definition and check if it is an enumeration with several * variants *) let def = C.ctx_lookup_type_decl ctx def_id in - assert (List.length regions = List.length def.T.region_params); + assert (List.length generics.regions = List.length def.T.generics.regions); (* Retrieve, for every variant, the list of its instantiated field types *) let variants_fields_types = - Subst.type_decl_get_instantiated_variants_fields_rtypes def regions types - cgs + Assoc.type_decl_get_inst_norm_variants_fields_rtypes ctx def generics in (* Check if there is strictly more than one variant *) if List.length variants_fields_types > 1 && not expand_enumerations then @@ -243,17 +241,6 @@ let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool) (* Initialize all the expanded values of all the variants *) List.map initialize variants_fields_types -(** Compute the expansion of an Option value. - *) -let compute_expanded_symbolic_option_value (expand_enumerations : bool) - (kind : V.sv_kind) (ty : T.rty) : V.symbolic_expansion list = - assert expand_enumerations; - let some_se = - V.SeAdt (Some T.option_some_id, [ mk_fresh_symbolic_value kind ty ]) - in - let none_se = V.SeAdt (Some T.option_none_id, []) in - [ none_se; some_se ] - let compute_expanded_symbolic_tuple_value (kind : V.sv_kind) (field_types : T.rty list) : V.symbolic_expansion = (* Generate the field values *) @@ -280,17 +267,14 @@ let compute_expanded_symbolic_box_value (kind : V.sv_kind) (boxed_ty : T.rty) : doesn't allow the expansion of enumerations *containing several variants*. *) let compute_expanded_symbolic_adt_value (expand_enumerations : bool) - (kind : V.sv_kind) (adt_id : T.type_id) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list - = - match (adt_id, regions, types) with + (kind : V.sv_kind) (adt_id : T.type_id) (generics : T.rgeneric_args) + (ctx : C.eval_ctx) : V.symbolic_expansion list = + match (adt_id, generics.regions, generics.types) with | T.AdtId def_id, _, _ -> compute_expanded_symbolic_non_assumed_adt_value expand_enumerations kind - def_id regions types cgs ctx - | T.Tuple, [], _ -> [ compute_expanded_symbolic_tuple_value kind types ] - | T.Assumed T.Option, [], [ ty ] -> - compute_expanded_symbolic_option_value expand_enumerations kind ty + def_id generics ctx + | T.Tuple, [], _ -> + [ compute_expanded_symbolic_tuple_value kind generics.types ] | T.Assumed T.Box, [], [ boxed_ty ] -> [ compute_expanded_symbolic_box_value kind boxed_ty ] | _ -> @@ -543,12 +527,12 @@ let expand_symbolic_value_no_branching (config : C.config) fun cf ctx -> match rty with (* ADTs *) - | T.Adt (adt_id, regions, types, cgs) -> + | T.Adt (adt_id, generics) -> (* Compute the expanded value *) let allow_branching = false in let seel = compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - regions types cgs ctx + generics ctx in (* There should be exacly one branch *) let see = Collections.List.to_cons_nil seel in @@ -600,12 +584,12 @@ let expand_symbolic_adt (config : C.config) (sv : V.symbolic_value) (* Execute *) match rty with (* ADTs *) - | T.Adt (adt_id, regions, types, cgs) -> + | T.Adt (adt_id, generics) -> let allow_branching = true in (* Compute the expanded value *) let seel = compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id - regions types cgs ctx + generics ctx in (* Apply *) let seel = List.map (fun see -> (Some see, cf_branches)) seel in @@ -679,7 +663,7 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = ^ symbolic_value_to_string ctx sv)); let cc : cm_fun = match sv.V.sv_ty with - | T.Adt (AdtId def_id, _, _, _) -> + | T.Adt (AdtId def_id, _) -> (* {!expand_symbolic_value_no_branching} checks if there are branchings, * but we prefer to also check it here - this leads to cleaner messages * and debugging *) @@ -704,16 +688,17 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = [config]): " ^ Print.name_to_string def.name)) else expand_symbolic_value_no_branching config sv None - | T.Adt ((Tuple | Assumed Box), _, _, _) | T.Ref (_, _, _) -> + | T.Adt ((Tuple | Assumed Box), _) | T.Ref (_, _, _) -> (* Ok *) expand_symbolic_value_no_branching config sv None - | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _, _, _) - -> + | T.Adt (Assumed (Array | Slice | Str), _) -> (* We can't expand those *) raise (Failure "Attempted to greedily expand an ADT which can't be expanded ") - | T.TypeVar _ | T.Literal _ | Never -> raise (Failure "Unreachable") + | T.TypeVar _ | T.Literal _ | Never | T.TraitType _ | T.Arrow _ + | T.RawPtr _ -> + raise (Failure "Unreachable") in (* Compose and continue *) comp cc expand cf ctx diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index 8b2070c6..245f3b77 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -7,6 +7,7 @@ module E = Expressions open Utils module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open TypesUtils open ValuesUtils @@ -141,11 +142,19 @@ let rec copy_value (allow_adt_copy : bool) (config : C.config) | V.Adt av -> (* Sanity check *) (match v.V.ty with - | T.Adt (T.Assumed (T.Box | Vec), _, _, _) -> + | T.Adt (T.Assumed T.Box, _) -> raise (Failure "Can't copy an assumed value other than Option") - | T.Adt (T.AdtId _, _, _, _) -> assert allow_adt_copy - | T.Adt ((T.Assumed Option | T.Tuple), _, _, _) -> () (* Ok *) - | T.Adt (T.Assumed (Slice | T.Array), [], [ ty ], []) -> + | T.Adt (T.AdtId _, _) as ty -> + assert (allow_adt_copy || ty_is_primitively_copyable ty) + | T.Adt (T.Tuple, _) -> () (* Ok *) + | T.Adt + ( T.Assumed (Slice | T.Array), + { + regions = []; + types = [ ty ]; + const_generics = []; + trait_refs = []; + } ) -> assert (ty_is_primitively_copyable ty) | _ -> raise (Failure "Unreachable")); let ctx, fields = @@ -230,17 +239,16 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) : let prepare : cm_fun = fun cf ctx -> match op with - | Expressions.Constant (ty, cv) -> + | E.Constant _ -> (* No need to reorganize the context *) - literal_to_typed_value (TypesUtils.ty_as_literal ty) cv |> ignore; cf ctx - | Expressions.Copy p -> + | E.Copy p -> (* Access the value *) let access = Read in (* Expand the symbolic values, if necessary *) let expand_prim_copy = true in access_rplace_reorganize config expand_prim_copy access p cf ctx - | Expressions.Move p -> + | E.Move p -> (* Access the value *) let access = Move in let expand_prim_copy = false in @@ -260,9 +268,71 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) ^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n")); (* Evaluate *) match op with - | Expressions.Constant (ty, cv) -> - cf (literal_to_typed_value (TypesUtils.ty_as_literal ty) cv) ctx - | Expressions.Copy p -> + | E.Constant cv -> ( + match cv.value with + | E.CLiteral lit -> + cf (literal_to_typed_value (TypesUtils.ty_as_literal cv.ty) lit) ctx + | E.CTraitConst (trait_ref, generics, const_name) -> ( + assert (generics = TypesUtils.mk_empty_generic_args); + match trait_ref.trait_id with + | T.TraitImpl _ -> + (* This shouldn't happen: if we refer to a concrete implementation, we + should directly refer to the top-level constant *) + raise (Failure "Unreachable") + | _ -> ( + (* We refer to a constant defined in a local clause: simply + introduce a fresh symbolic value *) + let ctx0 = ctx in + (* Lookup the trait declaration to retrieve the type of the symbolic value *) + let trait_decl = + C.ctx_lookup_trait_decl ctx + trait_ref.trait_decl_ref.trait_decl_id + in + let _, (ty, _) = + List.find (fun (name, _) -> name = const_name) trait_decl.consts + in + (* Introduce a fresh symbolic value *) + let v = mk_fresh_symbolic_typed_value_from_ety V.TraitConst ty in + (* Continue the evaluation *) + let e = cf v ctx in + (* We have to wrap the generated expression *) + match e with + | None -> None + | Some e -> + Some + (SymbolicAst.IntroSymbolic + ( ctx0, + None, + value_as_symbolic v.value, + SymbolicAst.TraitConstValue + (trait_ref, generics, const_name), + e )))) + | E.CVar vid -> ( + let ctx0 = ctx in + (* Lookup the const generic value *) + let cv = C.ctx_lookup_const_generic_value ctx vid in + (* Copy the value *) + let allow_adt_copy = false in + let ctx, v = copy_value allow_adt_copy config ctx cv in + (* Continue *) + let e = cf v ctx in + (* We have to wrap the generated expression *) + match e with + | None -> None + | Some e -> + (* If we are synthesizing a symbolic AST, it means that we are in symbolic + mode: the value of the const generic is necessarily symbolic. *) + assert (is_symbolic cv.V.value); + (* *) + Some + (SymbolicAst.IntroSymbolic + ( ctx0, + None, + value_as_symbolic v.value, + SymbolicAst.ConstGenericValue vid, + e ))) + | E.CFnPtr _ -> raise (Failure "TODO")) + | E.Copy p -> (* Access the value *) let access = Read in let cc = read_place access p in @@ -283,7 +353,7 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) in (* Compose and apply *) comp cc copy cf ctx - | Expressions.Move p -> + | E.Move p -> (* Access the value *) let access = Move in let cc = read_place access p in @@ -358,7 +428,7 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand) match mk_scalar sv.int_ty i with | Error _ -> cf (Error EPanic) | Ok sv -> cf (Ok { v with V.value = V.Literal (PV.Scalar sv) })) - | E.Cast (src_ty, tgt_ty), V.Literal (PV.Scalar sv) -> ( + | E.Cast (E.CastInteger (src_ty, tgt_ty)), V.Literal (PV.Scalar sv) -> ( assert (src_ty = sv.int_ty); let i = sv.PV.value in match mk_scalar tgt_ty i with @@ -384,7 +454,7 @@ let eval_unary_op_symbolic (config : C.config) (unop : E.unop) (op : E.operand) match (unop, v.V.ty) with | E.Not, (T.Literal Bool as lty) -> lty | E.Neg, (T.Literal (Integer _) as lty) -> lty - | E.Cast (_, tgt_ty), _ -> T.Literal (Integer tgt_ty) + | E.Cast (E.CastInteger (_, tgt_ty)), _ -> T.Literal (Integer tgt_ty) | _ -> raise (Failure "Invalid input for unop") in let res_sv = @@ -653,73 +723,46 @@ let eval_rvalue_aggregate (config : C.config) fun ctx -> (* Match on the aggregate kind *) match aggregate_kind with - | E.AggregatedTuple -> - let tys = List.map (fun (v : V.typed_value) -> v.V.ty) values in - let v = V.Adt { variant_id = None; field_values = values } in - let ty = T.Adt (T.Tuple, [], tys, []) in - let aggregated : V.typed_value = { V.value = v; ty } in - (* Call the continuation *) - cf aggregated ctx - | E.AggregatedOption (variant_id, ty) -> - (* Sanity check *) - if variant_id = T.option_none_id then assert (values = []) - else if variant_id = T.option_some_id then - assert (List.length values = 1) - else raise (Failure "Unreachable"); - (* Construt the value *) - let aty = T.Adt (T.Assumed T.Option, [], [ ty ], []) in - let av : V.adt_value = - { V.variant_id = Some variant_id; V.field_values = values } - in - let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in - (* Call the continuation *) - cf aggregated ctx - | E.AggregatedAdt (def_id, opt_variant_id, regions, types, cgs) -> - (* Sanity checks *) - let type_decl = C.ctx_lookup_type_decl ctx def_id in - assert (List.length type_decl.region_params = List.length regions); - let expected_field_types = - Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id - types cgs - in - assert ( - expected_field_types - = List.map (fun (v : V.typed_value) -> v.V.ty) values); - (* Construct the value *) - let av : V.adt_value = - { V.variant_id = opt_variant_id; V.field_values = values } - in - let aty = T.Adt (T.AdtId def_id, regions, types, cgs) in - let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in - (* Call the continuation *) - cf aggregated ctx - | E.AggregatedRange ety -> - (* There should be two fields exactly *) - let v0, v1 = - match values with - | [ v0; v1 ] -> (v0, v1) - | _ -> raise (Failure "Unreachable") - in - (* Ranges are parametric over the type of indices. For now we only - support scalars, which can be of any type *) - assert (literal_type_is_integer (ty_as_literal ety)); - assert (v0.ty = ety); - assert (v1.ty = ety); - (* Construct the value *) - let av : V.adt_value = - { V.variant_id = None; V.field_values = values } - in - let aty = T.Adt (T.Assumed T.Range, [], [ ety ], []) in - let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in - (* Call the continuation *) - cf aggregated ctx + | E.AggregatedAdt (type_id, opt_variant_id, generics) -> ( + match type_id with + | Tuple -> + let tys = List.map (fun (v : V.typed_value) -> v.V.ty) values in + let v = V.Adt { variant_id = None; field_values = values } in + let generics = TypesUtils.mk_generic_args [] tys [] [] in + let ty = T.Adt (T.Tuple, generics) in + let aggregated : V.typed_value = { V.value = v; ty } in + (* Call the continuation *) + cf aggregated ctx + | AdtId def_id -> + (* Sanity checks *) + let type_decl = C.ctx_lookup_type_decl ctx def_id in + assert ( + List.length type_decl.generics.regions + = List.length generics.regions); + let expected_field_types = + Assoc.ctx_adt_get_inst_norm_field_etypes ctx def_id opt_variant_id + generics + in + assert ( + expected_field_types + = List.map (fun (v : V.typed_value) -> v.V.ty) values); + (* Construct the value *) + let av : V.adt_value = + { V.variant_id = opt_variant_id; V.field_values = values } + in + let aty = T.Adt (T.AdtId def_id, generics) in + let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in + (* Call the continuation *) + cf aggregated ctx + | Assumed _ -> raise (Failure "Unreachable")) | E.AggregatedArray (ety, cg) -> ( (* Sanity check: all the values have the proper type *) assert (List.for_all (fun (v : V.typed_value) -> v.V.ty = ety) values); (* Sanity check: the number of values is consistent with the length *) let len = (literal_as_scalar (const_generic_as_literal cg)).value in assert (len = Z.of_int (List.length values)); - let ty = T.Adt (T.Assumed T.Array, [], [ ety ], [ cg ]) in + let generics = TypesUtils.mk_generic_args [] [ ety ] [ cg ] [] in + let ty = T.Adt (T.Assumed T.Array, generics) in (* In order to generate a better AST, we introduce a symbolic value equal to the array. The reason is that otherwise, the array we introduce here might be duplicated in the generated @@ -752,7 +795,7 @@ let eval_rvalue_not_global (config : C.config) (rvalue : E.rvalue) (* Delegate to the proper auxiliary function *) match rvalue with | E.Use op -> comp_wrap (eval_operand config op) ctx - | E.Ref (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx + | E.RvRef (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx | E.UnaryOp (unop, op) -> eval_unary_op config unop op cf ctx | E.BinaryOp (binop, op1, op2) -> eval_binary_op config binop op1 op2 cf ctx | E.Aggregate (aggregate_kind, ops) -> diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml index bf88e055..6d3ecb18 100644 --- a/compiler/InterpreterLoopsJoinCtxs.ml +++ b/compiler/InterpreterLoopsJoinCtxs.ml @@ -554,9 +554,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) C.type_context; fun_context; global_context; + trait_decls_context; + trait_impls_context; region_groups; type_vars; const_generic_vars; + const_generic_vars_map; + norm_trait_etypes; + norm_trait_rtypes; + norm_trait_stypes; env = _; ended_regions = ended_regions0; } = @@ -566,9 +572,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) C.type_context = _; fun_context = _; global_context = _; + trait_decls_context = _; + trait_impls_context = _; region_groups = _; type_vars = _; const_generic_vars = _; + const_generic_vars_map = _; + norm_trait_etypes = _; + norm_trait_rtypes = _; + norm_trait_stypes = _; env = _; ended_regions = ended_regions1; } = @@ -580,9 +592,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) C.type_context; fun_context; global_context; + trait_decls_context; + trait_impls_context; region_groups; type_vars; const_generic_vars; + const_generic_vars_map; + norm_trait_etypes; + norm_trait_rtypes; + norm_trait_stypes; env; ended_regions; } diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml index 9248e513..8cab546e 100644 --- a/compiler/InterpreterLoopsMatchCtxs.ml +++ b/compiler/InterpreterLoopsMatchCtxs.ml @@ -149,20 +149,25 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty) (match_regions : 'r -> 'r -> 'r) (ty0 : 'r T.ty) (ty1 : 'r T.ty) : 'r T.ty = let match_rec = match_types match_distinct_types match_regions in match (ty0, ty1) with - | Adt (id0, regions0, tys0, cgs0), Adt (id1, regions1, tys1, cgs1) -> + | Adt (id0, generics0), Adt (id1, generics1) -> assert (id0 = id1); - assert (cgs0 = cgs1); + assert (generics0.const_generics = generics1.const_generics); + assert (generics0.trait_refs = generics1.trait_refs); let id = id0 in - let cgs = cgs1 in + let const_generics = generics1.const_generics in + let trait_refs = generics1.trait_refs in let regions = List.map (fun (id0, id1) -> match_regions id0 id1) - (List.combine regions0 regions1) + (List.combine generics0.regions generics1.regions) in - let tys = - List.map (fun (ty0, ty1) -> match_rec ty0 ty1) (List.combine tys0 tys1) + let types = + List.map + (fun (ty0, ty1) -> match_rec ty0 ty1) + (List.combine generics0.types generics1.types) in - Adt (id, regions, tys, cgs) + let generics = { T.regions; types; const_generics; trait_refs } in + Adt (id, generics) | TypeVar vid0, TypeVar vid1 -> assert (vid0 = vid1); let vid = vid0 in diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml index 04dc8892..2a277c91 100644 --- a/compiler/InterpreterPaths.ml +++ b/compiler/InterpreterPaths.ml @@ -3,6 +3,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open Cps open ValuesUtils @@ -95,16 +96,14 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) | pe :: p' -> ( (* Match on the projection element and the value *) match (pe, v.V.value, v.V.ty) with - | ( Field (((ProjAdt (_, _) | ProjOption _) as proj_kind), field_id), + | ( Field ((ProjAdt (_, _) as proj_kind), field_id), V.Adt adt, - T.Adt (type_id, _, _, _) ) -> ( + T.Adt (type_id, _) ) -> ( (* Check consistency *) (match (proj_kind, type_id) with | ProjAdt (def_id, opt_variant_id), T.AdtId def_id' -> assert (def_id = def_id'); assert (opt_variant_id = adt.variant_id) - | ProjOption variant_id, T.Assumed T.Option -> - assert (Some variant_id = adt.variant_id) | _ -> raise (Failure "Unreachable")); (* Actually project *) let fv = T.FieldId.nth adt.field_values field_id in @@ -119,8 +118,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) let updated = { v with value = nadt } in Ok (ctx, { res with updated })) (* Tuples *) - | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _, _) - -> ( + | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _) -> ( assert (arity = List.length adt.field_values); let fv = T.FieldId.nth adt.field_values field_id in (* Project *) @@ -136,7 +134,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) Ok (ctx, { res with updated }) (* If we reach Bottom, it may mean we need to expand an uninitialized * enumeration value *)) - | Field ((ProjAdt (_, _) | ProjTuple _ | ProjOption _), _), V.Bottom, _ -> + | Field ((ProjAdt (_, _) | ProjTuple _), _), V.Bottom, _ -> Error (FailBottom (1 + List.length p', pe, v.ty)) (* Symbolic value: needs to be expanded *) | _, Symbolic sp, _ -> @@ -145,9 +143,9 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx) (* Box dereferencement *) | ( DerefBox, Adt { variant_id = None; field_values = [ bv ] }, - T.Adt (T.Assumed T.Box, _, _, _) ) -> ( - (* We allow moving inside of boxes. In practice, this kind of - * manipulations should happen only inside unsage code, so + T.Adt (T.Assumed T.Box, _) ) -> ( + (* We allow moving outside of boxes. In practice, this kind of + * manipulations should happen only inside unsafe code, so * it shouldn't happen due to user code, and we leverage it * when implementing box dereferencement for the concrete * interpreter *) @@ -357,45 +355,32 @@ let write_place (access : access_kind) (p : E.place) (nv : V.typed_value) | Error e -> raise (Failure ("Unreachable: " ^ show_path_fail_kind e)) | Ok ctx -> ctx -let compute_expanded_bottom_adt_value (tyctx : T.type_decl T.TypeDeclId.Map.t) +let compute_expanded_bottom_adt_value (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (regions : T.erased_region list) (types : T.ety list) - (cgs : T.const_generic list) : V.typed_value = + (generics : T.egeneric_args) : V.typed_value = (* Lookup the definition and check if it is an enumeration - it should be an enumeration if and only if the projection element is a field projection with *some* variant id. Retrieve the list of fields at the same time. *) - let def = T.TypeDeclId.Map.find def_id tyctx in - assert (List.length regions = List.length def.T.region_params); + let def = C.ctx_lookup_type_decl ctx def_id in + assert (List.length generics.regions = List.length def.T.generics.regions); (* Compute the field types *) let field_types = - Subst.type_decl_get_instantiated_field_etypes def opt_variant_id types cgs + Assoc.type_decl_get_inst_norm_field_etypes ctx def opt_variant_id generics in (* Initialize the expanded value *) let fields = List.map mk_bottom field_types in let av = V.Adt { variant_id = opt_variant_id; field_values = fields } in - let ty = T.Adt (T.AdtId def_id, regions, types, cgs) in + let ty = T.Adt (T.AdtId def_id, generics) in { V.value = av; V.ty } -let compute_expanded_bottom_option_value (variant_id : T.VariantId.id) - (param_ty : T.ety) : V.typed_value = - (* Note that the variant can be [Some] or [None]: we expand bottom values - * when writing to fields or setting discriminants *) - let field_values = - if variant_id = T.option_some_id then [ mk_bottom param_ty ] - else if variant_id = T.option_none_id then [] - else raise (Failure "Unreachable") - in - let av = V.Adt { variant_id = Some variant_id; field_values } in - let ty = T.Adt (T.Assumed T.Option, [], [ param_ty ], []) in - { V.value = av; ty } - let compute_expanded_bottom_tuple_value (field_types : T.ety list) : V.typed_value = (* Generate the field values *) let fields = List.map mk_bottom field_types in let v = V.Adt { variant_id = None; field_values = fields } in - let ty = T.Adt (T.Tuple, [], field_types, []) in + let generics = TypesUtils.mk_generic_args [] field_types [] [] in + let ty = T.Adt (T.Tuple, generics) in { V.value = v; V.ty } (** Auxiliary helper to expand {!V.Bottom} values. @@ -447,19 +432,18 @@ let expand_bottom_value_from_projection (access : access_kind) (p : E.place) match (pe, ty) with (* "Regular" ADTs *) | ( Field (ProjAdt (def_id, opt_variant_id), _), - T.Adt (T.AdtId def_id', regions, types, cgs) ) -> + T.Adt (T.AdtId def_id', generics) ) -> assert (def_id = def_id'); - compute_expanded_bottom_adt_value ctx.type_context.type_decls def_id - opt_variant_id regions types cgs - (* Option *) - | ( Field (ProjOption variant_id, _), - T.Adt (T.Assumed T.Option, [], [ ty ], []) ) -> - compute_expanded_bottom_option_value variant_id ty + compute_expanded_bottom_adt_value ctx def_id opt_variant_id generics (* Tuples *) - | Field (ProjTuple arity, _), T.Adt (T.Tuple, [], tys, []) -> - assert (arity = List.length tys); + | ( Field (ProjTuple arity, _), + T.Adt + ( T.Tuple, + { T.regions = []; types; const_generics = []; trait_refs = [] } ) ) + -> + assert (arity = List.length types); (* Generate the field values *) - compute_expanded_bottom_tuple_value tys + compute_expanded_bottom_tuple_value types | _ -> raise (Failure diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli index 4a9f3b41..0ff8063f 100644 --- a/compiler/InterpreterPaths.mli +++ b/compiler/InterpreterPaths.mli @@ -3,6 +3,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open Cps open InterpreterExpansion @@ -56,18 +57,12 @@ val compute_expanded_bottom_tuple_value : T.ety list -> V.typed_value (** Compute an expanded ADT ⊥ value *) val compute_expanded_bottom_adt_value : - T.type_decl T.TypeDeclId.Map.t -> + C.eval_ctx -> T.TypeDeclId.id -> T.VariantId.id option -> - T.erased_region list -> - T.ety list -> - T.const_generic list -> + T.egeneric_args -> V.typed_value -(** Compute an expanded [Option] ⊥ value *) -val compute_expanded_bottom_option_value : - T.VariantId.id -> T.ety -> V.typed_value - (** Drop (end) outer loans at a given place, which should be seen as an l-value (we will write to it later, but need to drop the loans before writing). diff --git a/compiler/InterpreterProjectors.ml b/compiler/InterpreterProjectors.ml index faed066b..9e0c2b75 100644 --- a/compiler/InterpreterProjectors.ml +++ b/compiler/InterpreterProjectors.ml @@ -3,6 +3,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module L = Logging open TypesUtils open InterpreterUtils @@ -24,12 +25,12 @@ let rec apply_proj_borrows_on_shared_borrow (ctx : C.eval_ctx) else match (v.V.value, ty) with | V.Literal _, T.Literal _ -> [] - | V.Adt adt, T.Adt (id, region_params, tys, cgs) -> + | V.Adt adt, T.Adt (id, generics) -> (* Retrieve the types of the fields *) let field_types = - Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id - region_params tys cgs + Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics in + (* Project over the field values *) let fields_types = List.combine adt.V.field_values field_types in let proj_fields = @@ -103,11 +104,10 @@ let rec apply_proj_borrows (check_symbolic_no_ended : bool) (ctx : C.eval_ctx) let value : V.avalue = match (v.V.value, ty) with | V.Literal _, T.Literal _ -> V.AIgnored - | V.Adt adt, T.Adt (id, region_params, tys, cgs) -> + | V.Adt adt, T.Adt (id, generics) -> (* Retrieve the types of the fields *) let field_types = - Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id - region_params tys cgs + Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics in (* Project over the field values *) let fields_types = List.combine adt.V.field_values field_types in @@ -268,8 +268,7 @@ let apply_proj_loans_on_symbolic_expansion (regions : T.RegionId.Set.t) let (value, ty) : V.avalue * T.rty = match (see, original_sv_ty) with | SeLiteral _, T.Literal _ -> (V.AIgnored, original_sv_ty) - | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys, _cgs) - -> + | SeAdt (variant_id, field_values), T.Adt (_id, _generics) -> (* Project over the field values *) let field_values = List.map diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 045c4484..e0c4703b 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -10,13 +10,13 @@ open TypesUtils open ValuesUtils module Inv = Invariants module S = SynthesizeSymbolic -open Utils open Cps open InterpreterUtils open InterpreterProjectors open InterpreterExpansion open InterpreterPaths open InterpreterExpressions +module PCtx = Print.EvalCtxLlbcAst (** The local logger *) let log = L.statements_log @@ -232,9 +232,7 @@ let set_discriminant (config : C.config) (p : E.place) let update_value cf (v : V.typed_value) : m_fun = fun ctx -> match (v.V.ty, v.V.value) with - | ( T.Adt - (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs), - V.Adt av ) -> ( + | T.Adt ((T.AdtId _ as type_id), generics), V.Adt av -> ( (* There are two situations: - either the discriminant is already the proper one (in which case we don't do anything) @@ -251,28 +249,17 @@ let set_discriminant (config : C.config) (p : E.place) let bottom_v = match type_id with | T.AdtId def_id -> - compute_expanded_bottom_adt_value - ctx.type_context.type_decls def_id (Some variant_id) - regions types cgs - | T.Assumed T.Option -> - assert (regions = []); - compute_expanded_bottom_option_value variant_id - (Collections.List.to_cons_nil types) + compute_expanded_bottom_adt_value ctx def_id + (Some variant_id) generics | _ -> raise (Failure "Unreachable") in assign_to_place config bottom_v p (cf Unit) ctx) - | ( T.Adt - (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs), - V.Bottom ) -> + | T.Adt ((T.AdtId _ as type_id), generics), V.Bottom -> let bottom_v = match type_id with | T.AdtId def_id -> - compute_expanded_bottom_adt_value ctx.type_context.type_decls - def_id (Some variant_id) regions types cgs - | T.Assumed T.Option -> - assert (regions = []); - compute_expanded_bottom_option_value variant_id - (Collections.List.to_cons_nil types) + compute_expanded_bottom_adt_value ctx def_id (Some variant_id) + generics | _ -> raise (Failure "Unreachable") in assign_to_place config bottom_v p (cf Unit) ctx @@ -301,24 +288,34 @@ let ctx_push_frame (ctx : C.eval_ctx) : C.eval_ctx = let push_frame : cm_fun = fun cf ctx -> cf (ctx_push_frame ctx) (** Small helper: compute the type of the return value for a specific - instantiation of a non-local function. + instantiation of an assumed function. *) -let get_non_local_function_return_type (fid : A.assumed_fun_id) - (region_params : T.erased_region list) (type_params : T.ety list) - (const_generic_params : T.const_generic list) : T.ety = +let get_assumed_function_return_type (ctx : C.eval_ctx) (fid : A.assumed_fun_id) + (generics : T.egeneric_args) : T.ety = + assert (generics.trait_refs = []); (* [Box::free] has a special treatment *) - match (fid, region_params, type_params, const_generic_params) with - | A.BoxFree, [], [ _ ], [] -> mk_unit_ty + match fid with + | BoxFree -> + assert (generics.regions = []); + assert (List.length generics.types = 1); + assert (generics.const_generics = []); + mk_unit_ty | _ -> (* Retrieve the function's signature *) - let sg = Assumed.get_assumed_sig fid in + let sg = Assumed.get_assumed_fun_sig fid in (* Instantiate the return type *) - let tsubst = Subst.make_type_subst_from_vars sg.type_params type_params in - let cgsubst = - Subst.make_const_generic_subst_from_vars sg.const_generic_params - const_generic_params + (* There shouldn't be any reference to Self *) + let tr_self : T.erased_region T.trait_instance_id = + T.UnknownTrait __FUNCTION__ + in + let { Subst.r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } = + Subst.make_esubst_from_generics sg.generics generics tr_self + in + let ty = + Subst.erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self + sg.output in - Subst.erase_regions_substitute_types tsubst cgsubst sg.output + Assoc.ctx_normalize_ety ctx ty let move_return_value (config : C.config) (pop_return_value : bool) (cf : V.typed_value option -> m_fun) : m_fun = @@ -418,19 +415,14 @@ let pop_frame_assign (config : C.config) (dest : E.place) : cm_fun = in comp cf_pop cf_assign -(** Auxiliary function - see {!eval_non_local_function_call} *) -let eval_replace_concrete (_config : C.config) - (_region_params : T.erased_region list) (_type_params : T.ety list) - (_cg_params : T.const_generic list) : cm_fun = - fun _cf _ctx -> raise Unimplemented - -(** Auxiliary function - see {!eval_non_local_function_call} *) -let eval_box_new_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) - (cg_params : T.const_generic list) : cm_fun = +(** Auxiliary function - see {!eval_assumed_function_call} *) +let eval_box_new_concrete (config : C.config) (generics : T.egeneric_args) : + cm_fun = fun cf ctx -> (* Check and retrieve the arguments *) - match (region_params, type_params, cg_params, ctx.env) with + match + (generics.regions, generics.types, generics.const_generics, ctx.env) + with | ( [], [ boxed_ty ], [], @@ -448,7 +440,8 @@ let eval_box_new_concrete (config : C.config) (* Create the new box *) let cf_create cf (moved_input_value : V.typed_value) : m_fun = (* Create the box value *) - let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ], []) in + let generics = TypesUtils.mk_generic_args_from_types [ boxed_ty ] in + let box_ty = T.Adt (T.Assumed T.Box, generics) in let box_v = V.Adt { variant_id = None; field_values = [ moved_input_value ] } in @@ -466,71 +459,7 @@ let eval_box_new_concrete (config : C.config) comp cf_move cf_create cf ctx | _ -> raise (Failure "Inconsistent state") -(** Auxiliary function which factorizes code to evaluate [std::Deref::deref] - and [std::DerefMut::deref_mut] - see {!eval_non_local_function_call} *) -let eval_box_deref_mut_or_shared_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) - (cg_params : T.const_generic list) (is_mut : bool) : cm_fun = - fun cf ctx -> - (* Check the arguments *) - match (region_params, type_params, cg_params, ctx.env) with - | ( [], - [ boxed_ty ], - [], - Var (VarBinder input_var, input_value) - :: Var (_ret_var, _) - :: C.Frame :: _ ) -> - (* Required type checking. We must have: - - input_value.ty = & (mut) Box<ty> - - boxed_ty = ty - for some ty - *) - (let _, input_ty, ref_kind = ty_get_ref input_value.V.ty in - assert (match ref_kind with T.Shared -> not is_mut | T.Mut -> is_mut); - let input_ty = ty_get_box input_ty in - assert (input_ty = boxed_ty)); - - (* Borrow the boxed value *) - let p = - { E.var_id = input_var.C.index; projection = [ E.Deref; E.DerefBox ] } - in - let borrow_kind = if is_mut then E.Mut else E.Shared in - let rv = E.Ref (p, borrow_kind) in - let cf_borrow = eval_rvalue_not_global config rv in - - (* Move the borrow to its destination *) - let cf_move cf res : m_fun = - match res with - | Error EPanic -> - (* We can't get there by borrowing a value *) - raise (Failure "Unreachable") - | Ok borrowed_value -> - (* Move and continue *) - let destp = mk_place_from_var_id E.VarId.zero in - assign_to_place config borrowed_value destp cf - in - - (* Compose and apply *) - comp cf_borrow cf_move cf ctx - | _ -> raise (Failure "Inconsistent state") - -(** Auxiliary function - see {!eval_non_local_function_call} *) -let eval_box_deref_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) - (cg_params : T.const_generic list) : cm_fun = - let is_mut = false in - eval_box_deref_mut_or_shared_concrete config region_params type_params - cg_params is_mut - -(** Auxiliary function - see {!eval_non_local_function_call} *) -let eval_box_deref_mut_concrete (config : C.config) - (region_params : T.erased_region list) (type_params : T.ety list) - (cg_params : T.const_generic list) : cm_fun = - let is_mut = true in - eval_box_deref_mut_or_shared_concrete config region_params type_params - cg_params is_mut - -(** Auxiliary function - see {!eval_non_local_function_call}. +(** Auxiliary function - see {!eval_assumed_function_call}. [Box::free] is not handled the same way as the other assumed functions: - in the regular case, whenever we need to evaluate an assumed function, @@ -549,11 +478,10 @@ let eval_box_deref_mut_concrete (config : C.config) It thus updates the box value (by calling {!drop_value}) and updates the destination (by setting it to [()]). *) -let eval_box_free (config : C.config) (region_params : T.erased_region list) - (type_params : T.ety list) (cg_params : T.const_generic list) +let eval_box_free (config : C.config) (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) : cm_fun = fun cf ctx -> - match (region_params, type_params, cg_params, args) with + match (generics.regions, generics.types, generics.const_generics, args) with | [], [ boxed_ty ], [], [ E.Move input_box_place ] -> (* Required type checking *) let input_box = InterpreterPaths.read_place Write input_box_place ctx in @@ -570,26 +498,24 @@ let eval_box_free (config : C.config) (region_params : T.erased_region list) cc cf ctx | _ -> raise (Failure "Inconsistent state") -(** Auxiliary function - see {!eval_non_local_function_call} *) -let eval_vec_function_concrete (_config : C.config) (_fid : A.assumed_fun_id) - (_region_params : T.erased_region list) (_type_params : T.ety list) - (_cg_params : T.const_generic list) : cm_fun = - fun _cf _ctx -> raise Unimplemented - (** Evaluate a non-local function call in concrete mode *) -let eval_non_local_function_call_concrete (config : C.config) - (fid : A.assumed_fun_id) (region_params : T.erased_region list) - (type_params : T.ety list) (cg_params : T.const_generic list) - (args : E.operand list) (dest : E.place) : cm_fun = +let eval_assumed_function_call_concrete (config : C.config) + (fid : A.assumed_fun_id) (call : A.call) : cm_fun = + let generics = call.func.generics in + let args = call.args in + let dest = call.dest in + (* Sanity check: we don't fully handle the const generic vars environment + in concrete mode yet *) + assert (generics.const_generics = []); (* There are two cases (and this is extremely annoying): - the function is not box_free - the function is box_free See {!eval_box_free} *) match fid with - | A.BoxFree -> + | BoxFree -> (* Degenerate case: box_free *) - eval_box_free config region_params type_params cg_params args dest + eval_box_free config generics args dest | _ -> (* "Normal" case: not box_free *) (* Evaluate the operands *) @@ -604,16 +530,14 @@ let eval_non_local_function_call_concrete (config : C.config) * but it made it less clear where the computed values came from, * so we reversed the modifications. *) let cf_eval_call cf (args_vl : V.typed_value list) : m_fun = + fun ctx -> (* Push the stack frame: we initialize the frame with the return variable, and one variable per input argument *) let cc = push_frame in (* Create and push the return variable *) let ret_vid = E.VarId.zero in - let ret_ty = - get_non_local_function_return_type fid region_params type_params - cg_params - in + let ret_ty = get_assumed_function_return_type ctx fid generics in let ret_var = mk_var ret_vid (Some "@return") ret_ty in let cc = comp cc (push_uninitialized_var ret_var) in @@ -630,24 +554,12 @@ let eval_non_local_function_call_concrete (config : C.config) * access to a body. *) let cf_eval_body : cm_fun = match fid with - | A.Replace -> - eval_replace_concrete config region_params type_params cg_params - | BoxNew -> - eval_box_new_concrete config region_params type_params cg_params - | BoxDeref -> - eval_box_deref_concrete config region_params type_params cg_params - | BoxDerefMut -> - eval_box_deref_mut_concrete config region_params type_params - cg_params + | BoxNew -> eval_box_new_concrete config generics | BoxFree -> (* Should have been treated above *) raise (Failure "Unreachable") - | VecNew | VecPush | VecInsert | VecLen | VecIndex | VecIndexMut -> - eval_vec_function_concrete config fid region_params type_params - cg_params | ArrayIndexShared | ArrayIndexMut | ArrayToSliceShared - | ArrayToSliceMut | ArraySubsliceShared | ArraySubsliceMut - | SliceIndexShared | SliceIndexMut | SliceSubsliceShared - | SliceSubsliceMut | SliceLen -> + | ArrayToSliceMut | ArrayRepeat | SliceIndexShared | SliceIndexMut + | SliceLen -> raise (Failure "Unimplemented") in @@ -657,50 +569,11 @@ let eval_non_local_function_call_concrete (config : C.config) let cc = comp cc (pop_frame_assign config dest) in (* Continue *) - cc cf + cc cf ctx in (* Compose and apply *) comp cf_eval_ops cf_eval_call -let instantiate_fun_sig (type_params : T.ety list) - (cg_params : T.const_generic list) (sg : A.fun_sig) : A.inst_fun_sig = - (* Generate fresh abstraction ids and create a substitution from region - * group ids to abstraction ids *) - let rg_abs_ids_bindings = - List.map - (fun rg -> - let abs_id = C.fresh_abstraction_id () in - (rg.T.id, abs_id)) - sg.regions_hierarchy - in - let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t = - List.fold_left - (fun mp (rg_id, abs_id) -> T.RegionGroupId.Map.add rg_id abs_id mp) - T.RegionGroupId.Map.empty rg_abs_ids_bindings - in - let asubst (rg_id : T.RegionGroupId.id) : V.AbstractionId.id = - T.RegionGroupId.Map.find rg_id asubst_map - in - (* Generate fresh regions and their substitutions *) - let _, rsubst, _ = Subst.fresh_regions_with_substs sg.region_params in - (* Generate the type substitution - * Note that we need the substitution to map the type variables to - * {!rty} types (not {!ety}). In order to do that, we convert the - * type parameters to types with regions. This is possible only - * if those types don't contain any regions. - * This is a current limitation of the analysis: there is still some - * work to do to properly handle full type parametrization. - * *) - let rtype_params = List.map ety_no_regions_to_rty type_params in - let tsubst = Subst.make_type_subst_from_vars sg.type_params rtype_params in - let cgsubst = - Subst.make_const_generic_subst_from_vars sg.const_generic_params cg_params - in - (* Substitute the signature *) - let inst_sig = Subst.substitute_signature asubst rsubst tsubst cgsubst sg in - (* Return *) - inst_sig - (** Helper Create abstractions (with no avalues, which have to be inserted afterwards) @@ -836,7 +709,7 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun = match rvalue with | E.Global _ -> raise (Failure "Unreachable") | E.Use _ - | E.Ref (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow)) + | E.RvRef (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow)) | E.UnaryOp _ | E.BinaryOp _ | E.Discriminant _ | E.Aggregate _ -> let rp = rvalue_get_place rvalue in @@ -893,7 +766,15 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id) match config.mode with | ConcreteMode -> (* Treat the evaluation of the global as a call to the global body (without arguments) *) - (eval_local_function_call_concrete config global.body_id [] [] [] [] dest) + let func = + { + E.func = FunId (Regular global.body_id); + generics = TypesUtils.mk_empty_generic_args; + trait_and_method_generic_args = None; + } + in + let call = { A.func; args = []; dest } in + (eval_transparent_function_call_concrete config global.body_id call) cf ctx | SymbolicMode -> (* Generate a fresh symbolic value. In the translation, this fresh symbolic value will be @@ -1037,128 +918,374 @@ and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun = (** Evaluate a function call (auxiliary helper for [eval_statement]) *) and eval_function_call (config : C.config) (call : A.call) : st_cm_fun = - (* There are two cases: + (* There are several cases: - this is a local function, in which case we execute its body - - this is a non-local function, in which case there is a special treatment + - this is an assumed function, in which case there is a special treatment + - this is a trait method *) - match call.func with - | A.Regular fid -> - eval_local_function_call config fid call.region_args call.type_args - call.const_generic_args call.args call.dest - | A.Assumed fid -> - eval_non_local_function_call config fid call.region_args call.type_args - call.const_generic_args call.args call.dest + match config.mode with + | C.ConcreteMode -> eval_function_call_concrete config call + | C.SymbolicMode -> eval_function_call_symbolic config call -(** Evaluate a local (i.e., non-assumed) function call in concrete mode *) -and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) - (_region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : - st_cm_fun = +and eval_function_call_concrete (config : C.config) (call : A.call) : st_cm_fun + = fun cf ctx -> - (* Retrieve the (correctly instantiated) body *) - let def = C.ctx_lookup_fun_decl ctx fid in - (* We can evaluate the function call only if it is not opaque *) - let body = - match def.body with - | None -> - raise - (Failure - ("Can't evaluate a call to an opaque function: " - ^ Print.name_to_string def.name)) - | Some body -> body - in - let tsubst = - Subst.make_type_subst_from_vars def.A.signature.type_params type_args - in - let cgsubst = - Subst.make_const_generic_subst_from_vars - def.A.signature.const_generic_params cg_args - in - let locals, body_st = Subst.fun_body_substitute_in_body tsubst cgsubst body in - - (* Evaluate the input operands *) - assert (List.length args = body.A.arg_count); - let cc = eval_operands config args in - - (* Push a frame delimiter - we use {!comp_transmit} to transmit the result - * of the operands evaluation from above to the functions afterwards, while - * ignoring it in this function *) - let cc = comp_transmit cc push_frame in - - (* Compute the initial values for the local variables *) - (* 1. Push the return value *) - let ret_var, locals = - match locals with - | ret_ty :: locals -> (ret_ty, locals) - | _ -> raise (Failure "Unreachable") - in - let input_locals, locals = - Collections.List.split_at locals body.A.arg_count - in + match call.func.func with + | FunId (Regular fid) -> + eval_transparent_function_call_concrete config fid call cf ctx + | FunId (Assumed fid) -> + (* Continue - note that we do as if the function call has been successful, + * by giving {!Unit} to the continuation, because we place us in the case + * where we haven't panicked. Of course, the translation needs to take the + * panic case into account... *) + eval_assumed_function_call_concrete config fid call (cf Unit) ctx + | TraitMethod _ -> raise (Failure "Unimplemented") + +and eval_function_call_symbolic (config : C.config) (call : A.call) : st_cm_fun + = + match call.func.func with + | FunId (Regular _) | TraitMethod _ -> + eval_transparent_function_call_symbolic config call + | FunId (Assumed fid) -> eval_assumed_function_call_symbolic config fid call - let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in - - (* 2. Push the input values *) - let cf_push_inputs cf args = - let inputs = List.combine input_locals args in - (* Note that this function checks that the variables and their values - * have the same type (this is important) *) - push_vars inputs cf - in - let cc = comp cc cf_push_inputs in - - (* 3. Push the remaining local variables (initialized as {!Bottom}) *) - let cc = comp cc (push_uninitialized_vars locals) in +(** Evaluate a local (i.e., non-assumed) function call in concrete mode *) +and eval_transparent_function_call_concrete (config : C.config) + (fid : A.FunDeclId.id) (call : A.call) : st_cm_fun = + let generics = call.func.generics in + let args = call.A.args in + let dest = call.A.dest in + (* Sanity check: we don't fully handle the const generic vars environment + in concrete mode yet *) + assert (generics.const_generics = []); + fun cf ctx -> + (* Retrieve the (correctly instantiated) body *) + let def = C.ctx_lookup_fun_decl ctx fid in + (* We can evaluate the function call only if it is not opaque *) + let body = + match def.body with + | None -> + raise + (Failure + ("Can't evaluate a call to an opaque function: " + ^ Print.name_to_string def.name)) + | Some body -> body + in + (* TODO: we need to normalize the types if we want to correctly support traits *) + assert (generics.trait_refs = []); + (* There shouldn't be any reference to Self *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = + Subst.make_esubst_from_generics def.A.signature.generics generics tr_self + in + let locals, body_st = Subst.fun_body_substitute_in_body subst body in + + (* Evaluate the input operands *) + assert (List.length args = body.A.arg_count); + let cc = eval_operands config args in + + (* Push a frame delimiter - we use {!comp_transmit} to transmit the result + * of the operands evaluation from above to the functions afterwards, while + * ignoring it in this function *) + let cc = comp_transmit cc push_frame in + + (* Compute the initial values for the local variables *) + (* 1. Push the return value *) + let ret_var, locals = + match locals with + | ret_ty :: locals -> (ret_ty, locals) + | _ -> raise (Failure "Unreachable") + in + let input_locals, locals = + Collections.List.split_at locals body.A.arg_count + in - (* Execute the function body *) - let cc = comp cc (eval_function_body config body_st) in + let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in - (* Pop the stack frame and move the return value to its destination *) - let cf_finish cf res = - match res with - | Panic -> cf Panic - | Return -> - (* Pop the stack frame, retrieve the return value, move it to - * its destination and continue *) - pop_frame_assign config dest (cf Unit) - | Break _ | Continue _ | Unit | LoopReturn _ | EndEnterLoop _ - | EndContinue _ -> - raise (Failure "Unreachable") - in - let cc = comp cc cf_finish in + (* 2. Push the input values *) + let cf_push_inputs cf args = + let inputs = List.combine input_locals args in + (* Note that this function checks that the variables and their values + * have the same type (this is important) *) + push_vars inputs cf + in + let cc = comp cc cf_push_inputs in + + (* 3. Push the remaining local variables (initialized as {!Bottom}) *) + let cc = comp cc (push_uninitialized_vars locals) in + + (* Execute the function body *) + let cc = comp cc (eval_function_body config body_st) in + + (* Pop the stack frame and move the return value to its destination *) + let cf_finish cf res = + match res with + | Panic -> cf Panic + | Return -> + (* Pop the stack frame, retrieve the return value, move it to + * its destination and continue *) + pop_frame_assign config dest (cf Unit) + | Break _ | Continue _ | Unit | LoopReturn _ | EndEnterLoop _ + | EndContinue _ -> + raise (Failure "Unreachable") + in + let cc = comp cc cf_finish in - (* Continue *) - cc cf ctx + (* Continue *) + cc cf ctx (** Evaluate a local (i.e., non-assumed) function call in symbolic mode *) -and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id) - (region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : - st_cm_fun = +and eval_transparent_function_call_symbolic (config : C.config) (call : A.call) + : st_cm_fun = fun cf ctx -> - (* Retrieve the (correctly instantiated) signature *) - let def = C.ctx_lookup_fun_decl ctx fid in - let sg = def.A.signature in - (* Instantiate the signature and introduce fresh abstraction and region ids - * while doing so *) - let inst_sg = instantiate_fun_sig type_args cg_args sg in + (* Instantiate the signature and introduce fresh abstractions and region ids while doing so. + + We perform some manipulations when instantiating the signature. + + # Trait impl calls + ================== + In particular, we have a special treatment of trait method calls when + the trait ref is a known impl. + + For instance: + {[ + trait HasValue { + fn has_value(&self) -> bool; + } + + impl<T> HasValue for Option<T> { + fn has_value(&self) { + match self { + None => false, + Some(_) => true, + } + } + } + + fn option_has_value<T>(x: &Option<T>) -> bool { + x.has_value() + } + ]} + + The generated code looks like this: + {[ + structure HasValue (Self : Type) = { + has_value : Self -> result bool + } + + let OptionHasValueImpl.has_value (Self : Type) (self : Self) : result bool = + match self with + | None => false + | Some _ => true + + let OptionHasValueInstance (T : Type) : HasValue (Option T) = { + has_value = OptionHasValueInstance.has_value + } + ]} + + In [option_has_value], we don't want to refer to the [has_value] method + of the instance of [HasValue] for [Option<T>]. We want to refer directly + to the function which implements [has_value] for [Option<T>]. + That is, instead of generating this: + {[ + let option_has_value (T : Type) (x : Option T) : result bool = + (OptionHasValueInstance T).has_value x + ]} + + We want to generate this: + {[ + let option_has_value (T : Type) (x : Option T) : result bool = + OptionHasValueImpl.has_value T x + ]} + + # Provided trait methods + ======================== + Calls to provided trait methods also have a special treatment because + for now we forbid overriding provided trait methods in the trait implementations, + which means that whenever we call a provided trait method, we do not refer + to a trait clause but directly to the method provided in the trait declaration. + *) + let func, generics, def, inst_sg = + match call.func.func with + | FunId (Regular fid) -> + let def = C.ctx_lookup_fun_decl ctx fid in + log#ldebug + (lazy + ("fun call:\n- call: " ^ call_to_string ctx call + ^ "\n- call.generics:\n" + ^ egeneric_args_to_string ctx call.func.generics + ^ "\n- def.signature:\n" + ^ fun_sig_to_string ctx def.A.signature)); + let tr_self = T.UnknownTrait __FUNCTION__ in + let inst_sg = + instantiate_fun_sig ctx call.func.generics tr_self def.A.signature + in + (call.func.func, call.func.generics, def, inst_sg) + | FunId (Assumed _) -> + (* Unreachable: must be a transparent function *) + raise (Failure "Unreachable") + | TraitMethod (trait_ref, method_name, _) -> ( + log#ldebug + (lazy + ("trait method call:\n- call: " ^ call_to_string ctx call + ^ "\n- method name: " ^ method_name ^ "\n- call.generics:\n" + ^ egeneric_args_to_string ctx call.func.generics + ^ "\n- trait and method generics:\n" + ^ egeneric_args_to_string ctx + (Option.get call.func.trait_and_method_generic_args))); + (* When instantiating, we need to group the generics for the trait ref + and the method *) + let generics = Option.get call.func.trait_and_method_generic_args in + (* Lookup the trait method signature - there are several possibilities + depending on whethere we call a top-level trait method impl or the + method from a local clause *) + match trait_ref.trait_id with + | TraitImpl impl_id -> ( + (* Lookup the trait impl *) + let trait_impl = C.ctx_lookup_trait_impl ctx impl_id in + log#ldebug + (lazy ("trait impl: " ^ trait_impl_to_string ctx trait_impl)); + (* First look in the required methods *) + let method_id = + List.find_opt + (fun (s, _) -> s = method_name) + trait_impl.required_methods + in + match method_id with + | Some (_, id) -> + (* This is a required method *) + let method_def = C.ctx_lookup_fun_decl ctx id in + (* Instantiate *) + let tr_self = + T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref) + in + let inst_sg = + instantiate_fun_sig ctx generics tr_self + method_def.A.signature + in + (* Also update the function identifier: we want to forget + the fact that we called a trait method, and treat it as + a regular function call to the top-level function + which implements the method. In order to do this properly, + we also need to update the generics. + *) + let func = E.FunId (Regular id) in + (func, generics, method_def, inst_sg) + | None -> + (* If not found, lookup the methods provided by the trait *declaration* + (remember: for now, we forbid overriding provided methods) *) + assert (trait_impl.provided_methods = []); + let trait_decl = + C.ctx_lookup_trait_decl ctx + trait_ref.trait_decl_ref.trait_decl_id + in + let _, method_id = + List.find + (fun (s, _) -> s = method_name) + trait_decl.provided_methods + in + let method_id = Option.get method_id in + let method_def = C.ctx_lookup_fun_decl ctx method_id in + (* For the instantiation we have to do something peculiar + because the method was defined for the trait declaration. + We have to group: + - the parameters given to the trait decl reference + - the parameters given to the method itself + For instance: + {[ + trait Foo<T> { + fn f<U>(...) { ... } + } + + fn g<G>(x : G) where Clause0: Foo<G, bool> + { + x.f::<u32>(...) // The arguments to f are: <G, bool, u32> + } + ]} + *) + let all_generics = + TypesUtils.merge_generic_args + trait_ref.trait_decl_ref.decl_generics call.func.generics + in + log#ldebug + (lazy + ("provided method call:" ^ "\n- method name: " ^ method_name + ^ "\n- all_generics:\n" + ^ egeneric_args_to_string ctx all_generics + ^ "\n- parent params info: " + ^ Print.option_to_string A.show_params_info + method_def.signature.parent_params_info)); + let tr_self = + T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref) + in + let inst_sg = + instantiate_fun_sig ctx all_generics tr_self + method_def.A.signature + in + (call.func.func, call.func.generics, method_def, inst_sg)) + | _ -> + (* We are using a local clause - we lookup the trait decl *) + let trait_decl = + C.ctx_lookup_trait_decl ctx trait_ref.trait_decl_ref.trait_decl_id + in + (* Lookup the method decl in the required *and* the provided methods *) + let _, method_id = + let provided = + List.filter_map + (fun (id, f) -> + match f with None -> None | Some f -> Some (id, f)) + trait_decl.provided_methods + in + List.find + (fun (s, _) -> s = method_name) + (List.append trait_decl.required_methods provided) + in + let method_def = C.ctx_lookup_fun_decl ctx method_id in + log#ldebug (lazy ("method:\n" ^ fun_decl_to_string ctx method_def)); + (* Instantiate *) + let tr_self = T.TraitRef trait_ref in + let tr_self = + TypesUtils.etrait_instance_id_no_regions_to_gr_trait_instance_id + tr_self + in + let inst_sg = + instantiate_fun_sig ctx generics tr_self method_def.A.signature + in + (call.func.func, call.func.generics, method_def, inst_sg)) + in (* Sanity check *) - assert (List.length args = List.length def.A.signature.inputs); + assert (List.length call.args = List.length def.A.signature.inputs); (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config (A.Regular fid) inst_sg - region_args type_args cg_args args dest cf ctx + eval_function_call_symbolic_from_inst_sig config func inst_sg generics + call.args call.dest cf ctx (** Evaluate a function call in symbolic mode by using the function signature. This allows us to factorize the evaluation of local and non-local function calls in symbolic mode: only their signatures matter. + + The [self_trait_ref] trait ref refers to [Self]. We use it when calling + a provided trait method, because those methods have a special treatment: + we dot not group them with the required trait methods, and forbid (for now) + overriding them. We treat them as regular method, which take an additional + trait ref as input. *) and eval_function_call_symbolic_from_inst_sig (config : C.config) - (fid : A.fun_id) (inst_sg : A.inst_fun_sig) - (_region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : + (fid : A.fun_id_or_trait_method_ref) (inst_sg : A.inst_fun_sig) + (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) : st_cm_fun = fun cf ctx -> + log#ldebug + (lazy + ("eval_function_call_symbolic_from_inst_sig:\n- fid: " + ^ fun_id_or_trait_method_ref_to_string ctx fid + ^ "\n- inst_sg:\n" + ^ inst_fun_sig_to_string ctx inst_sg + ^ "\n- call.generics:\n" + ^ egeneric_args_to_string ctx generics + ^ "\n- args:\n" + ^ String.concat ", " (List.map (operand_to_string ctx) args) + ^ "\n- dest:\n" ^ place_to_string ctx dest)); + (* Generate a fresh symbolic value for the return value *) let ret_sv_ty = inst_sg.A.output in let ret_spc = mk_fresh_symbolic_value V.FunCallRet ret_sv_ty in @@ -1224,8 +1351,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) let expr = cf ctx in (* Synthesize the symbolic AST *) - S.synthesize_regular_function_call fid call_id ctx abs_ids type_args cg_args - args args_places ret_spc dest_place expr + S.synthesize_regular_function_call fid call_id ctx abs_ids generics args + args_places ret_spc dest_place expr in let cc = comp cc cf_call in @@ -1286,17 +1413,18 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) cc (cf Unit) ctx (** Evaluate a non-local function call in symbolic mode *) -and eval_non_local_function_call_symbolic (config : C.config) - (fid : A.assumed_fun_id) (region_args : T.erased_region list) - (type_args : T.ety list) (cg_args : T.const_generic list) - (args : E.operand list) (dest : E.place) : st_cm_fun = +and eval_assumed_function_call_symbolic (config : C.config) + (fid : A.assumed_fun_id) (call : A.call) : st_cm_fun = fun cf ctx -> + let generics = call.func.generics in + let args = call.args in + let dest = call.dest in (* Sanity check: make sure the type parameters don't contain regions - * this is a current limitation of our synthesis *) assert ( List.for_all (fun ty -> not (ty_has_borrows ctx.type_context.type_infos ty)) - type_args); + generics.types); (* There are two cases (and this is extremely annoying): - the function is not box_free @@ -1304,10 +1432,10 @@ and eval_non_local_function_call_symbolic (config : C.config) See {!eval_box_free} *) match fid with - | A.BoxFree -> + | BoxFree -> (* Degenerate case: box_free - note that this is not really a function * call: no need to call a "synthesize_..." function *) - eval_box_free config region_args type_args cg_args args dest (cf Unit) ctx + eval_box_free config generics args dest (cf Unit) ctx | _ -> (* "Normal" case: not box_free *) (* In symbolic mode, the behaviour of a function call is completely defined @@ -1315,59 +1443,19 @@ and eval_non_local_function_call_symbolic (config : C.config) * instantiated signatures, and delegate the work to an auxiliary function *) let inst_sig = match fid with - | A.BoxFree -> + | BoxFree -> (* should have been treated above *) raise (Failure "Unreachable") | _ -> - instantiate_fun_sig type_args cg_args (Assumed.get_assumed_sig fid) + (* There shouldn't be any reference to Self *) + let tr_self = T.UnknownTrait __FUNCTION__ in + instantiate_fun_sig ctx generics tr_self + (Assumed.get_assumed_fun_sig fid) in (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config (A.Assumed fid) inst_sig - region_args type_args cg_args args dest cf ctx - -(** Evaluate a non-local (i.e, assumed) function call such as [Box::deref] - (auxiliary helper for [eval_statement]) *) -and eval_non_local_function_call (config : C.config) (fid : A.assumed_fun_id) - (region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : - st_cm_fun = - fun cf ctx -> - (* Debug *) - log#ldebug - (lazy - (let type_args = - "[" ^ String.concat ", " (List.map (ety_to_string ctx) type_args) ^ "]" - in - let args = - "[" ^ String.concat ", " (List.map (operand_to_string ctx) args) ^ "]" - in - let dest = place_to_string ctx dest in - "eval_non_local_function_call:\n- fid:" ^ A.show_assumed_fun_id fid - ^ "\n- type_args: " ^ type_args ^ "\n- args: " ^ args ^ "\n- dest: " - ^ dest)); - - match config.mode with - | C.ConcreteMode -> - eval_non_local_function_call_concrete config fid region_args type_args - cg_args args dest (cf Unit) ctx - | C.SymbolicMode -> - eval_non_local_function_call_symbolic config fid region_args type_args - cg_args args dest cf ctx - -(** Evaluate a local (i.e, not assumed) function call (auxiliary helper for - [eval_statement]) *) -and eval_local_function_call (config : C.config) (fid : A.FunDeclId.id) - (region_args : T.erased_region list) (type_args : T.ety list) - (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) : - st_cm_fun = - match config.mode with - | ConcreteMode -> - eval_local_function_call_concrete config fid region_args type_args cg_args - args dest - | SymbolicMode -> - eval_local_function_call_symbolic config fid region_args type_args cg_args - args dest + eval_function_call_symbolic_from_inst_sig config (FunId (Assumed fid)) + inst_sig generics args dest cf ctx (** Evaluate a statement seen as a function body *) and eval_function_body (config : C.config) (body : A.statement) : st_cm_fun = diff --git a/compiler/InterpreterStatements.mli b/compiler/InterpreterStatements.mli index 814bc964..e65758ae 100644 --- a/compiler/InterpreterStatements.mli +++ b/compiler/InterpreterStatements.mli @@ -25,15 +25,6 @@ open InterpreterExpressions *) val pop_frame : C.config -> bool -> (V.typed_value option -> m_fun) -> m_fun -(** Instantiate a function signature, introducing **fresh** abstraction ids and - region ids. This is mostly used in preparation of function calls, when - evaluating in symbolic mode of course. - - Note: there are no region parameters, because they should be erased. - *) -val instantiate_fun_sig : - T.ety list -> T.const_generic list -> LA.fun_sig -> LA.inst_fun_sig - (** Helper. Create a list of abstractions from a list of regions groups, and insert diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml index 7bd37550..6e08e553 100644 --- a/compiler/InterpreterUtils.ml +++ b/compiler/InterpreterUtils.ml @@ -10,6 +10,11 @@ open TypesUtils module PA = Print.EvalCtxLlbcAst open Cps +(* TODO: we should probably rename the file to ContextsUtils *) + +(** The local logger *) +let log = L.interpreter_log + (** Some utilities *) (** Auxiliary function - call a function which requires a continuation, @@ -38,6 +43,20 @@ let typed_value_to_string = PA.typed_value_to_string let typed_avalue_to_string = PA.typed_avalue_to_string let place_to_string = PA.place_to_string let operand_to_string = PA.operand_to_string +let egeneric_args_to_string = PA.egeneric_args_to_string +let rtrait_instance_id_to_string = PA.rtrait_instance_id_to_string +let fun_sig_to_string = PA.fun_sig_to_string +let inst_fun_sig_to_string = PA.inst_fun_sig_to_string + +let fun_id_or_trait_method_ref_to_string = + PA.fun_id_or_trait_method_ref_to_string + +let fun_decl_to_string = PA.fun_decl_to_string +let call_to_string = PA.call_to_string + +let trait_impl_to_string ctx = + PA.trait_impl_to_string { ctx with type_vars = []; const_generic_vars = [] } + let statement_to_string ctx = PA.statement_to_string ctx "" " " let statement_to_string_with_tab ctx = PA.statement_to_string ctx " " " " let env_elem_to_string ctx = PA.env_elem_to_string ctx "" " " @@ -255,7 +274,8 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : C.eval_ctx) raise Found else () | V.SynthInput | V.SynthInputGivenBack | V.FunCallGivenBack - | V.SynthRetGivenBack | V.Global | V.LoopGivenBack | V.Aggregate -> + | V.SynthRetGivenBack | V.Global | V.LoopGivenBack | V.Aggregate + | V.ConstGeneric | V.TraitConst -> () end in @@ -272,7 +292,7 @@ let rvalue_get_place (rv : E.rvalue) : E.place option = match rv with | Use (Copy p | Move p) -> Some p | Use (Constant _) -> None - | Ref (p, _) -> Some p + | RvRef (p, _) -> Some p | UnaryOp _ | BinaryOp _ | Global _ | Discriminant _ | Aggregate _ -> None (** See {!ValuesUtils.symbolic_value_has_borrows} *) @@ -403,3 +423,103 @@ let compute_contexts_ids (ctxl : C.eval_ctx list) : ids_sets * ids_to_values = (** Compute the sets of ids found in a context. *) let compute_context_ids (ctx : C.eval_ctx) : ids_sets * ids_to_values = compute_contexts_ids [ ctx ] + +(** **WARNING**: this function doesn't compute the normalized types + (for the trait type aliases). This should be computed afterwards. + *) +let initialize_eval_context (ctx : C.decls_ctx) + (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) + (const_generic_vars : T.const_generic_var list) : C.eval_ctx = + C.reset_global_counters (); + let const_generic_vars_map = + T.ConstGenericVarId.Map.of_list + (List.map + (fun (cg : T.const_generic_var) -> + let ty = TypesUtils.ety_no_regions_to_rty (T.Literal cg.ty) in + let cv = mk_fresh_symbolic_typed_value V.ConstGeneric ty in + (cg.index, cv)) + const_generic_vars) + in + { + C.type_context = ctx.type_ctx; + C.fun_context = ctx.fun_ctx; + C.global_context = ctx.global_ctx; + C.trait_decls_context = ctx.trait_decls_ctx; + C.trait_impls_context = ctx.trait_impls_ctx; + C.region_groups; + C.type_vars; + C.const_generic_vars; + C.const_generic_vars_map; + C.norm_trait_etypes = C.ETraitTypeRefMap.empty (* Empty for now *); + C.norm_trait_rtypes = C.RTraitTypeRefMap.empty (* Empty for now *); + C.norm_trait_stypes = C.STraitTypeRefMap.empty (* Empty for now *); + C.env = [ C.Frame ]; + C.ended_regions = T.RegionId.Set.empty; + } + +(** Instantiate a function signature, introducing **fresh** abstraction ids and + region ids. This is mostly used in preparation of function calls (when + evaluating in symbolic mode). + + Note: there are no region parameters, because they should be erased. + *) +let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.egeneric_args) + (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig = + log#ldebug + (lazy + ("instantiate_fun_sig:" ^ "\n- generics: " + ^ egeneric_args_to_string ctx generics + ^ "\n- tr_self: " + ^ rtrait_instance_id_to_string ctx tr_self + ^ "\n- sg: " ^ fun_sig_to_string ctx sg)); + (* Generate fresh abstraction ids and create a substitution from region + * group ids to abstraction ids *) + let rg_abs_ids_bindings = + List.map + (fun rg -> + let abs_id = C.fresh_abstraction_id () in + (rg.T.id, abs_id)) + sg.regions_hierarchy + in + let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t = + List.fold_left + (fun mp (rg_id, abs_id) -> T.RegionGroupId.Map.add rg_id abs_id mp) + T.RegionGroupId.Map.empty rg_abs_ids_bindings + in + let asubst (rg_id : T.RegionGroupId.id) : V.AbstractionId.id = + T.RegionGroupId.Map.find rg_id asubst_map + in + (* Generate fresh regions and their substitutions *) + let _, rsubst, _ = Subst.fresh_regions_with_substs sg.generics.regions in + (* Generate the type substitution + * Note that we need the substitution to map the type variables to + * {!rty} types (not {!ety}). In order to do that, we convert the + * type parameters to types with regions. This is possible only + * if those types don't contain any regions. + * This is a current limitation of the analysis: there is still some + * work to do to properly handle full type parametrization. + * *) + let rtype_params = List.map ety_no_regions_to_rty generics.types in + let tsubst = Subst.make_type_subst_from_vars sg.generics.types rtype_params in + let cgsubst = + Subst.make_const_generic_subst_from_vars sg.generics.const_generics + generics.const_generics + in + (* TODO: something annoying with the trait ref subst: we need to use region + types, but the arguments use erased regions. For now we use the fact + that no regions should appear inside. In the future: we should merge + ety and rty. *) + let trait_refs = + List.map TypesUtils.etrait_ref_no_regions_to_gr_trait_ref + generics.trait_refs + in + let tr_subst = + Subst.make_trait_subst_from_clauses sg.generics.trait_clauses trait_refs + in + (* Substitute the signature *) + let inst_sig = + AssociatedTypes.ctx_subst_norm_signature ctx asubst rsubst tsubst cgsubst + tr_subst tr_self sg + in + (* Return *) + inst_sig diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml index f29c7f88..5c8ec7af 100644 --- a/compiler/Invariants.ml +++ b/compiler/Invariants.ml @@ -7,6 +7,7 @@ module V = Values module E = Expressions module C = Contexts module Subst = Substitute +module Assoc = AssociatedTypes module A = LlbcAst module L = Logging open Cps @@ -406,13 +407,14 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (match (tv.V.value, tv.V.ty) with | V.Literal cv, T.Literal ty -> check_literal_type cv ty (* ADT case *) - | V.Adt av, T.Adt (T.AdtId def_id, regions, tys, cgs) -> + | V.Adt av, T.Adt (T.AdtId def_id, generics) -> (* Retrieve the definition to check the variant id, the number of * parameters, etc. *) let def = C.ctx_lookup_type_decl ctx def_id in (* Check the number of parameters *) - assert (List.length regions = List.length def.region_params); - assert (List.length tys = List.length def.type_params); + assert ( + List.length generics.regions = List.length def.generics.regions); + assert (List.length generics.types = List.length def.generics.types); (* Check that the variant id is consistent *) (match (av.V.variant_id, def.T.kind) with | Some variant_id, T.Enum variants -> @@ -421,8 +423,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = | _ -> raise (Failure "Erroneous typing")); (* Check that the field types are correct *) let field_types = - Subst.type_decl_get_instantiated_field_etypes def av.V.variant_id - tys cgs + Assoc.type_decl_get_inst_norm_field_etypes ctx def av.V.variant_id + generics in let fields_with_types = List.combine av.V.field_values field_types @@ -431,34 +433,31 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty)) fields_with_types (* Tuple case *) - | V.Adt av, T.Adt (T.Tuple, regions, tys, cgs) -> - assert (regions = []); - assert (cgs = []); + | V.Adt av, T.Adt (T.Tuple, generics) -> + assert (generics.regions = []); + assert (generics.const_generics = []); assert (av.V.variant_id = None); (* Check that the fields have the proper values - and check that there * are as many fields as field types at the same time *) - let fields_with_types = List.combine av.V.field_values tys in + let fields_with_types = + List.combine av.V.field_values generics.types + in List.iter (fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty)) fields_with_types (* Assumed type case *) - | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> ( - assert (av.V.variant_id = None || aty_id = T.Option); - match (aty_id, av.V.field_values, regions, tys, cgs) with + | V.Adt av, T.Adt (T.Assumed aty_id, generics) -> ( + assert (av.V.variant_id = None); + match + ( aty_id, + av.V.field_values, + generics.regions, + generics.types, + generics.const_generics ) + with (* Box *) - | T.Box, [ inner_value ], [], [ inner_ty ], [] - | T.Option, [ inner_value ], [], [ inner_ty ], [] -> + | T.Box, [ inner_value ], [], [ inner_ty ], [] -> assert (inner_value.V.ty = inner_ty) - | T.Option, _, [], [ _ ], [] -> - (* Option::None: nothing to check *) - () - | T.Vec, fvs, [], [ vec_ty ], [] -> - List.iter - (fun (v : V.typed_value) -> assert (v.ty = vec_ty)) - fvs - | T.Range, [ v0; v1 ], [], [ inner_ty ], [] -> - assert (v0.V.ty = inner_ty); - assert (v1.V.ty = inner_ty) | T.Array, inner_values, _, [ inner_ty ], [ cg ] -> (* *) assert ( @@ -520,14 +519,17 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (* Check the current pair (value, type) *) (match (atv.V.value, atv.V.ty) with (* ADT case *) - | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys, cgs) -> + | V.AAdt av, T.Adt (T.AdtId def_id, generics) -> (* Retrieve the definition to check the variant id, the number of * parameters, etc. *) let def = C.ctx_lookup_type_decl ctx def_id in (* Check the number of parameters *) - assert (List.length regions = List.length def.region_params); - assert (List.length tys = List.length def.type_params); - assert (List.length cgs = List.length def.const_generic_params); + assert ( + List.length generics.regions = List.length def.generics.regions); + assert (List.length generics.types = List.length def.generics.types); + assert ( + List.length generics.const_generics + = List.length def.generics.const_generics); (* Check that the variant id is consistent *) (match (av.V.variant_id, def.T.kind) with | Some variant_id, T.Enum variants -> @@ -536,8 +538,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = | _ -> raise (Failure "Erroneous typing")); (* Check that the field types are correct *) let field_types = - Subst.type_decl_get_instantiated_field_rtypes def av.V.variant_id - regions tys cgs + Assoc.type_decl_get_inst_norm_field_rtypes ctx def av.V.variant_id + generics in let fields_with_types = List.combine av.V.field_values field_types @@ -546,20 +548,28 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit = (fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty)) fields_with_types (* Tuple case *) - | V.AAdt av, T.Adt (T.Tuple, regions, tys, cgs) -> - assert (regions = []); - assert (cgs = []); + | V.AAdt av, T.Adt (T.Tuple, generics) -> + assert (generics.regions = []); + assert (generics.const_generics = []); assert (av.V.variant_id = None); (* Check that the fields have the proper values - and check that there * are as many fields as field types at the same time *) - let fields_with_types = List.combine av.V.field_values tys in + let fields_with_types = + List.combine av.V.field_values generics.types + in List.iter (fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty)) fields_with_types (* Assumed type case *) - | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> ( + | V.AAdt av, T.Adt (T.Assumed aty_id, generics) -> ( assert (av.V.variant_id = None); - match (aty_id, av.V.field_values, regions, tys, cgs) with + match + ( aty_id, + av.V.field_values, + generics.regions, + generics.types, + generics.const_generics ) + with (* Box *) | T.Box, [ boxed_value ], [], [ boxed_ty ], [] -> assert (boxed_value.V.ty = boxed_ty) diff --git a/compiler/LlbcAst.ml b/compiler/LlbcAst.ml index f4d26e18..2db859b2 100644 --- a/compiler/LlbcAst.ml +++ b/compiler/LlbcAst.ml @@ -11,6 +11,7 @@ type abs_region_groups = (AbstractionId.id, RegionId.id) g_region_groups (** A function signature, after instantiation *) type inst_fun_sig = { regions_hierarchy : abs_region_groups; + trait_type_constraints : rtrait_type_constraint list; inputs : rty list; output : rty; } diff --git a/compiler/LlbcAstUtils.ml b/compiler/LlbcAstUtils.ml index 1111c297..0ab4ed94 100644 --- a/compiler/LlbcAstUtils.ml +++ b/compiler/LlbcAstUtils.ml @@ -5,10 +5,46 @@ let lookup_fun_sig (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) : fun_sig = match fun_id with | Regular id -> (FunDeclId.Map.find id fun_decls).signature - | Assumed aid -> Assumed.get_assumed_sig aid + | Assumed aid -> Assumed.get_assumed_fun_sig aid let lookup_fun_name (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) : Names.fun_name = match fun_id with | Regular id -> (FunDeclId.Map.find id fun_decls).name - | Assumed aid -> Assumed.get_assumed_name aid + | Assumed aid -> Assumed.get_assumed_fun_name aid + +(** Return the opaque declarations found in the crate, which are also *not builtin*. + + [filter_assumed]: if [true], do not consider as opaque the external definitions + that we will map to definitions from the standard library. + + Remark: the list of functions also contains the list of opaque global bodies. + *) +let crate_get_opaque_non_builtin_decls (k : crate) (filter_assumed : bool) : + T.type_decl list * fun_decl list = + let open ExtractBuiltin in + let is_opaque_fun (d : fun_decl) : bool = + let sname = name_to_simple_name d.name in + d.body = None + (* Something to pay attention to: we must ignore trait method *declarations* + (which don't have a body but must not be considered as opaque) *) + && (match d.kind with TraitMethodDecl _ -> false | _ -> true) + && ((not filter_assumed) + || (not (SimpleNameMap.mem sname builtin_globals_map)) + && not (SimpleNameMap.mem sname (builtin_funs_map ()))) + in + let is_opaque_type (d : T.type_decl) : bool = + let sname = name_to_simple_name d.name in + d.kind = T.Opaque + && ((not filter_assumed) + || not (SimpleNameMap.mem sname (builtin_types_map ()))) + in + (* Note that by checking the function bodies we also the globals *) + ( List.filter is_opaque_type (T.TypeDeclId.Map.values k.types), + List.filter is_opaque_fun (FunDeclId.Map.values k.functions) ) + +(** Return true if the crate contains opaque declarations, ignoring the assumed + definitions. *) +let crate_has_opaque_non_builtin_decls (k : crate) (filter_assumed : bool) : + bool = + crate_get_opaque_non_builtin_decls k filter_assumed <> ([], []) diff --git a/compiler/Logging.ml b/compiler/Logging.ml index 9dc1f5e3..721655b8 100644 --- a/compiler/Logging.ml +++ b/compiler/Logging.ml @@ -9,6 +9,9 @@ let pre_passes_log = L.get_logger "MainLogger.PrePasses" (** Logger for Translate *) let translate_log = L.get_logger "MainLogger.Translate" +(** Logger for Contexts *) +let contexts_log = L.get_logger "MainLogger.Contexts" + (** Logger for PureUtils *) let pure_utils_log = L.get_logger "MainLogger.PureUtils" @@ -19,7 +22,7 @@ let symbolic_to_pure_log = L.get_logger "MainLogger.SymbolicToPure" let pure_micro_passes_log = L.get_logger "MainLogger.PureMicroPasses" (** Logger for ExtractBase *) -let pure_to_extract_log = L.get_logger "MainLogger.ExtractBase" +let extract_log = L.get_logger "MainLogger.ExtractBase" (** Logger for Interpreter *) let interpreter_log = L.get_logger "MainLogger.Interpreter" @@ -57,6 +60,9 @@ let borrows_log = L.get_logger "MainLogger.Interpreter.Borrows" (** Logger for Invariants *) let invariants_log = L.get_logger "MainLogger.Interpreter.Invariants" +(** Logger for AssociatedTypes *) +let associated_types_log = L.get_logger "MainLogger.AssociatedTypes" + (** Logger for SCC *) let scc_log = L.get_logger "MainLogger.Graph.SCC" diff --git a/compiler/PrePasses.ml b/compiler/PrePasses.ml index b348ba1d..ee06fa07 100644 --- a/compiler/PrePasses.ml +++ b/compiler/PrePasses.ml @@ -107,8 +107,8 @@ let remove_useless_cf_merges (crate : A.crate) (f : A.fun_decl) : A.fun_decl = false | Assign (_, rv) -> ( match rv with - | Use _ | Ref _ -> not must_end_with_exit - | Aggregate (AggregatedTuple, []) -> not must_end_with_exit + | Use _ | RvRef _ -> not must_end_with_exit + | Aggregate (AggregatedAdt (Tuple, _, _), []) -> not must_end_with_exit | _ -> false) | FakeRead _ | Drop _ | Nop -> not must_end_with_exit | Panic | Return -> true @@ -376,7 +376,7 @@ let remove_shallow_borrows (crate : A.crate) (f : A.fun_decl) : A.fun_decl = method! visit_Assign env p rv = match (p.projection, rv) with - | [], E.Ref (_, E.Shallow) -> + | [], E.RvRef (_, E.Shallow) -> (* Filter *) filtered := E.VarId.Set.add p.var_id !filtered; Nop diff --git a/compiler/Print.ml b/compiler/Print.ml index 9aa73d7c..7f0d95ff 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -21,6 +21,9 @@ module Values = struct type_decl_id_to_string : T.TypeDeclId.id -> string; const_generic_var_id_to_string : T.ConstGenericVarId.id -> string; global_decl_id_to_string : T.GlobalDeclId.id -> string; + trait_decl_id_to_string : T.TraitDeclId.id -> string; + trait_impl_id_to_string : T.TraitImplId.id -> string; + trait_clause_id_to_string : T.TraitClauseId.id -> string; adt_variant_to_string : T.TypeDeclId.id -> T.VariantId.id -> string; var_id_to_string : E.VarId.id -> string; adt_field_names : @@ -34,6 +37,9 @@ module Values = struct PT.type_decl_id_to_string = fmt.type_decl_id_to_string; PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; PT.global_decl_id_to_string = fmt.global_decl_id_to_string; + PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string; + PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string; + PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let value_to_rtype_formatter (fmt : value_formatter) : PT.rtype_formatter = @@ -43,6 +49,9 @@ module Values = struct PT.type_decl_id_to_string = fmt.type_decl_id_to_string; PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; PT.global_decl_id_to_string = fmt.global_decl_id_to_string; + PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string; + PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string; + PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let value_to_stype_formatter (fmt : value_formatter) : PT.stype_formatter = @@ -52,6 +61,9 @@ module Values = struct PT.type_decl_id_to_string = fmt.type_decl_id_to_string; PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; PT.global_decl_id_to_string = fmt.global_decl_id_to_string; + PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string; + PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string; + PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let var_id_to_string (id : E.VarId.id) : string = @@ -86,10 +98,10 @@ module Values = struct List.map (typed_value_to_string fmt) av.field_values in match v.ty with - | T.Adt (T.Tuple, _, _, _) -> + | T.Adt (T.Tuple, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | T.Adt (T.AdtId def_id, _, _, _) -> + | T.Adt (T.AdtId def_id, _) -> (* "Regular" ADT *) let adt_ident = match av.variant_id with @@ -111,21 +123,10 @@ module Values = struct let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | T.Adt (T.Assumed aty, _, _, _) -> ( + | T.Adt (T.Assumed aty, _) -> ( (* Assumed type *) match (aty, field_values) with | Box, [ bv ] -> "@Box(" ^ bv ^ ")" - | Option, _ -> - if av.variant_id = Some T.option_some_id then - "@Option::Some(" - ^ Collections.List.to_cons_nil field_values - ^ ")" - else if av.variant_id = Some T.option_none_id then ( - assert (field_values = []); - "@Option::None") - else raise (Failure "Unreachable") - | Range, _ -> "@Range{ " ^ String.concat ", " field_values ^ "}" - | Vec, _ -> "@Vec[" ^ String.concat ", " field_values ^ "]" | Array, _ -> (* Happens when we aggregate values *) "@Array[" ^ String.concat ", " field_values ^ "]" @@ -201,10 +202,10 @@ module Values = struct List.map (typed_avalue_to_string fmt) av.field_values in match v.ty with - | T.Adt (T.Tuple, _, _, _) -> + | T.Adt (T.Tuple, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | T.Adt (T.AdtId def_id, _, _, _) -> + | T.Adt (T.AdtId def_id, _) -> (* "Regular" ADT *) let adt_ident = match av.variant_id with @@ -226,7 +227,7 @@ module Values = struct let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | T.Adt (T.Assumed aty, _, _, _) -> ( + | T.Adt (T.Assumed aty, _) -> ( (* Assumed type *) match (aty, field_values) with | Box, [ bv ] -> "@Box(" ^ bv ^ ")" @@ -347,6 +348,18 @@ module Values = struct ^ "}" ^ "{regions=" ^ T.RegionId.Set.to_string None abs.regions ^ "}" ^ " {\n" ^ avs ^ "\n" ^ indent ^ "}" + + let inst_fun_sig_to_string (fmt : value_formatter) (sg : LlbcAst.inst_fun_sig) + : string = + (* TODO: print the trait type constraints? *) + let ty_fmt = value_to_rtype_formatter fmt in + let ty_to_string = PT.ty_to_string ty_fmt in + + let inputs = + "(" ^ String.concat ", " (List.map ty_to_string sg.inputs) ^ ")" + in + let output = ty_to_string sg.output in + inputs ^ " -> " ^ output end module PV = Values (* local module *) @@ -452,6 +465,9 @@ module Contexts = struct PV.adt_variant_to_string = fmt.adt_variant_to_string; PV.var_id_to_string = fmt.var_id_to_string; PV.adt_field_names = fmt.adt_field_names; + PV.trait_decl_id_to_string = fmt.trait_decl_id_to_string; + PV.trait_impl_id_to_string = fmt.trait_impl_id_to_string; + PV.trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let ast_to_value_formatter (fmt : PA.ast_formatter) : PV.value_formatter = @@ -463,20 +479,27 @@ module Contexts = struct let ctx_to_rtype_formatter (fmt : ctx_formatter) : PT.rtype_formatter = PV.value_to_rtype_formatter fmt + let ctx_to_stype_formatter (fmt : ctx_formatter) : PT.stype_formatter = + PV.value_to_stype_formatter fmt + let eval_ctx_to_ctx_formatter (ctx : C.eval_ctx) : ctx_formatter = - (* We shouldn't use rvar_to_string *) - let rvar_to_string _r = - raise (Failure "Unexpected use of rvar_to_string") + let rvar_to_string r = + (* In theory we shouldn't use rvar_to_string, but it can happen + when printing definitions for instance... *) + T.RegionVarId.to_string r in let r_to_string r = PT.region_id_to_string r in let type_var_id_to_string vid = - let v = C.lookup_type_var ctx vid in - v.name + (* The context may be invalid *) + match C.lookup_type_var_opt ctx vid with + | None -> T.TypeVarId.to_string vid + | Some v -> v.name in let const_generic_var_id_to_string vid = - let v = C.lookup_const_generic_var ctx vid in - v.name + match C.lookup_const_generic_var_opt ctx vid with + | None -> T.ConstGenericVarId.to_string vid + | Some v -> v.name in let type_decl_id_to_string def_id = let def = C.ctx_lookup_type_decl ctx def_id in @@ -486,6 +509,15 @@ module Contexts = struct let def = C.ctx_lookup_global_decl ctx def_id in name_to_string def.name in + let trait_decl_id_to_string def_id = + let def = C.ctx_lookup_trait_decl ctx def_id in + name_to_string def.name + in + let trait_impl_id_to_string def_id = + let def = C.ctx_lookup_trait_impl ctx def_id in + name_to_string def.name + in + let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in let adt_variant_to_string = PT.type_ctx_to_adt_variant_to_string_fun ctx.type_context.type_decls in @@ -506,6 +538,9 @@ module Contexts = struct adt_variant_to_string; var_id_to_string; adt_field_names; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; } let eval_ctx_to_ast_formatter (ctx : C.eval_ctx) : PA.ast_formatter = @@ -521,6 +556,15 @@ module Contexts = struct let def = C.ctx_lookup_global_decl ctx def_id in global_name_to_string def.name in + let trait_decl_id_to_string def_id = + let def = C.ctx_lookup_trait_decl ctx def_id in + name_to_string def.name + in + let trait_impl_id_to_string def_id = + let def = C.ctx_lookup_trait_impl ctx def_id in + name_to_string def.name + in + let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in { rvar_to_string = ctx_fmt.PV.rvar_to_string; r_to_string = ctx_fmt.PV.r_to_string; @@ -533,6 +577,9 @@ module Contexts = struct adt_field_to_string; fun_decl_id_to_string; global_decl_id_to_string; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; } (** Split an [env] at every occurrence of [Frame], eliminating those elements. @@ -608,6 +655,68 @@ module EvalCtxLlbcAst = struct let fmt = PC.ctx_to_rtype_formatter fmt in PT.rty_to_string fmt t + let sty_to_string (ctx : C.eval_ctx) (t : T.sty) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_stype_formatter fmt in + PT.sty_to_string fmt t + + let generic_params_to_strings (ctx : C.eval_ctx) (x : T.generic_params) : + string list * string list = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_stype_formatter fmt in + PT.generic_params_to_strings fmt x + + let egeneric_args_to_string (ctx : C.eval_ctx) (x : T.egeneric_args) : string + = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_etype_formatter fmt in + PT.egeneric_args_to_string fmt x + + let rgeneric_args_to_string (ctx : C.eval_ctx) (x : T.rgeneric_args) : string + = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_rtype_formatter fmt in + PT.rgeneric_args_to_string fmt x + + let sgeneric_args_to_string (ctx : C.eval_ctx) (x : T.sgeneric_args) : string + = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_stype_formatter fmt in + PT.sgeneric_args_to_string fmt x + + let etrait_ref_to_string (ctx : C.eval_ctx) (x : T.etrait_ref) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_etype_formatter fmt in + PT.etrait_ref_to_string fmt x + + let rtrait_ref_to_string (ctx : C.eval_ctx) (x : T.rtrait_ref) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_rtype_formatter fmt in + PT.rtrait_ref_to_string fmt x + + let strait_ref_to_string (ctx : C.eval_ctx) (x : T.strait_ref) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_stype_formatter fmt in + PT.strait_ref_to_string fmt x + + let etrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.etrait_instance_id) + : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_etype_formatter fmt in + PT.etrait_instance_id_to_string fmt x + + let rtrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.rtrait_instance_id) + : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_rtype_formatter fmt in + PT.rtrait_instance_id_to_string fmt x + + let strait_instance_id_to_string (ctx : C.eval_ctx) (x : T.strait_instance_id) + : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_stype_formatter fmt in + PT.strait_instance_id_to_string fmt x + let borrow_content_to_string (ctx : C.eval_ctx) (bc : V.borrow_content) : string = let fmt = PC.eval_ctx_to_ctx_formatter ctx in @@ -653,11 +762,38 @@ module EvalCtxLlbcAst = struct let fmt = PC.eval_ctx_to_ast_formatter ctx in PE.operand_to_string fmt op + let call_to_string (ctx : C.eval_ctx) (call : A.call) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PA.call_to_string fmt "" call + + let fun_decl_to_string (ctx : C.eval_ctx) (f : A.fun_decl) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PA.fun_decl_to_string fmt "" " " f + + let fun_sig_to_string (ctx : C.eval_ctx) (x : A.fun_sig) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PA.fun_sig_to_string fmt "" " " x + + let inst_fun_sig_to_string (ctx : C.eval_ctx) (x : LlbcAst.inst_fun_sig) : + string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + let fmt = PC.ast_to_value_formatter fmt in + PV.inst_fun_sig_to_string fmt x + + let fun_id_or_trait_method_ref_to_string (ctx : C.eval_ctx) + (x : E.fun_id_or_trait_method_ref) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PE.fun_id_or_trait_method_ref_to_string fmt x "..." + let statement_to_string (ctx : C.eval_ctx) (indent : string) (indent_incr : string) (e : A.statement) : string = let fmt = PC.eval_ctx_to_ast_formatter ctx in PA.statement_to_string fmt indent indent_incr e + let trait_impl_to_string (ctx : C.eval_ctx) (timpl : A.trait_impl) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PA.trait_impl_to_string fmt " " " " timpl + let env_elem_to_string (ctx : C.eval_ctx) (indent : string) (indent_incr : string) (ev : C.env_elem) : string = let fmt = PC.eval_ctx_to_ctx_formatter ctx in diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index cfb63ec2..ec75fcfd 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -8,6 +8,9 @@ type type_formatter = { type_decl_id_to_string : TypeDeclId.id -> string; const_generic_var_id_to_string : ConstGenericVarId.id -> string; global_decl_id_to_string : GlobalDeclId.id -> string; + trait_decl_id_to_string : TraitDeclId.id -> string; + trait_impl_id_to_string : TraitImplId.id -> string; + trait_clause_id_to_string : TraitClauseId.id -> string; } type value_formatter = { @@ -18,6 +21,9 @@ type value_formatter = { adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string; var_id_to_string : VarId.id -> string; adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option; + trait_decl_id_to_string : TraitDeclId.id -> string; + trait_impl_id_to_string : TraitImplId.id -> string; + trait_clause_id_to_string : TraitClauseId.id -> string; } let value_to_type_formatter (fmt : value_formatter) : type_formatter = @@ -26,6 +32,9 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter = type_decl_id_to_string = fmt.type_decl_id_to_string; const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; global_decl_id_to_string = fmt.global_decl_id_to_string; + trait_decl_id_to_string = fmt.trait_decl_id_to_string; + trait_impl_id_to_string = fmt.trait_impl_id_to_string; + trait_clause_id_to_string = fmt.trait_clause_id_to_string; } (* TODO: we need to store which variables we have encountered so far, and @@ -42,6 +51,9 @@ type ast_formatter = { adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option; fun_decl_id_to_string : FunDeclId.id -> string; global_decl_id_to_string : GlobalDeclId.id -> string; + trait_decl_id_to_string : TraitDeclId.id -> string; + trait_impl_id_to_string : TraitImplId.id -> string; + trait_clause_id_to_string : TraitClauseId.id -> string; } let ast_to_value_formatter (fmt : ast_formatter) : value_formatter = @@ -53,6 +65,9 @@ let ast_to_value_formatter (fmt : ast_formatter) : value_formatter = adt_variant_to_string = fmt.adt_variant_to_string; var_id_to_string = fmt.var_id_to_string; adt_field_names = fmt.adt_field_names; + trait_decl_id_to_string = fmt.trait_decl_id_to_string; + trait_impl_id_to_string = fmt.trait_impl_id_to_string; + trait_clause_id_to_string = fmt.trait_clause_id_to_string; } let ast_to_type_formatter (fmt : ast_formatter) : type_formatter = @@ -70,31 +85,51 @@ let literal_type_to_string = Print.PrimitiveValues.literal_type_to_string let scalar_value_to_string = Print.PrimitiveValues.scalar_value_to_string let literal_to_string = Print.PrimitiveValues.literal_to_string +(* Remark: not using generic_params on purpose, because we may use parameters + which either come from LLBC or from pure, and the [generic_params] type + for those ASTs is not the same. Note that it works because we actually don't + need to know the trait clauses to print the AST: we can thus ignore them. +*) let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) (global_decls : A.global_decl GlobalDeclId.Map.t) - (type_params : type_var list) + (trait_decls : A.trait_decl TraitDeclId.Map.t) + (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list) (const_generic_params : const_generic_var list) : type_formatter = let type_var_id_to_string vid = - let var = T.TypeVarId.nth type_params vid in + let var = TypeVarId.nth type_params vid in type_var_to_string var in let const_generic_var_id_to_string vid = - let var = T.ConstGenericVarId.nth const_generic_params vid in + let var = ConstGenericVarId.nth const_generic_params vid in const_generic_var_to_string var in let type_decl_id_to_string def_id = - let def = T.TypeDeclId.Map.find def_id type_decls in + let def = TypeDeclId.Map.find def_id type_decls in name_to_string def.name in let global_decl_id_to_string def_id = - let def = T.GlobalDeclId.Map.find def_id global_decls in + let def = GlobalDeclId.Map.find def_id global_decls in + name_to_string def.name + in + let trait_decl_id_to_string def_id = + let def = TraitDeclId.Map.find def_id trait_decls in + name_to_string def.name + in + let trait_impl_id_to_string def_id = + let def = TraitImplId.Map.find def_id trait_impls in name_to_string def.name in + let trait_clause_id_to_string id = + Print.PT.trait_clause_id_to_pretty_string id + in { type_var_id_to_string; type_decl_id_to_string; const_generic_var_id_to_string; global_decl_id_to_string; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; } (* TODO: there is a bit of duplication with Print.fun_decl_to_ast_formatter. @@ -106,19 +141,21 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) (fun_decls : A.fun_decl FunDeclId.Map.t) (global_decls : A.global_decl GlobalDeclId.Map.t) - (type_params : type_var list) + (trait_decls : A.trait_decl TraitDeclId.Map.t) + (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list) (const_generic_params : const_generic_var list) : ast_formatter = - let type_var_id_to_string vid = - let var = T.TypeVarId.nth type_params vid in - type_var_to_string var - in - let const_generic_var_id_to_string vid = - let var = T.ConstGenericVarId.nth const_generic_params vid in - const_generic_var_to_string var - in - let type_decl_id_to_string def_id = - let def = T.TypeDeclId.Map.find def_id type_decls in - name_to_string def.name + let ({ + type_var_id_to_string; + type_decl_id_to_string; + const_generic_var_id_to_string; + global_decl_id_to_string; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; + } + : type_formatter) = + mk_type_formatter type_decls global_decls trait_decls trait_impls + type_params const_generic_params in let adt_variant_to_string = Print.Types.type_ctx_to_adt_variant_to_string_fun type_decls @@ -137,10 +174,6 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) let def = FunDeclId.Map.find def_id fun_decls in fun_name_to_string def.name in - let global_decl_id_to_string def_id = - let def = GlobalDeclId.Map.find def_id global_decls in - global_name_to_string def.name - in { type_var_id_to_string; const_generic_var_id_to_string; @@ -151,6 +184,9 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t) adt_field_to_string; fun_decl_id_to_string; global_decl_id_to_string; + trait_decl_id_to_string; + trait_impl_id_to_string; + trait_clause_id_to_string; } let assumed_ty_to_string (aty : assumed_ty) : string = @@ -159,12 +195,11 @@ let assumed_ty_to_string (aty : assumed_ty) : string = | Result -> "Result" | Error -> "Error" | Fuel -> "Fuel" - | Option -> "Option" - | Vec -> "Vec" | Array -> "Array" | Slice -> "Slice" | Str -> "Str" - | Range -> "Range" + | RawPtr Mut -> "MutRawPtr" + | RawPtr Const -> "ConstRawPtr" let type_id_to_string (fmt : type_formatter) (id : type_id) : string = match id with @@ -182,20 +217,18 @@ let const_generic_to_string (fmt : type_formatter) (cg : T.const_generic) : let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string = match ty with - | Adt (id, tys, cgs) -> ( - let tys = List.map (ty_to_string fmt false) tys in - let cgs = List.map (const_generic_to_string fmt) cgs in - let params = List.append tys cgs in + | Adt (id, generics) -> ( match id with | Tuple -> - assert (cgs = []); - "(" ^ String.concat " * " tys ^ ")" + let generics = generic_args_to_strings fmt false generics in + "(" ^ String.concat " * " generics ^ ")" | AdtId _ | Assumed _ -> - let params_s = - if params = [] then "" else " " ^ String.concat " " params + let generics = generic_args_to_strings fmt true generics in + let generics_s = + if generics = [] then "" else " " ^ String.concat " " generics in - let ty_s = type_id_to_string fmt id ^ params_s in - if params <> [] && inside then "(" ^ ty_s ^ ")" else ty_s) + let ty_s = type_id_to_string fmt id ^ generics_s in + if generics <> [] && inside then "(" ^ ty_s ^ ")" else ty_s) | TypeVar tv -> fmt.type_var_id_to_string tv | Literal lty -> literal_type_to_string lty | Arrow (arg_ty, ret_ty) -> @@ -203,6 +236,71 @@ let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string = ty_to_string fmt true arg_ty ^ " -> " ^ ty_to_string fmt false ret_ty in if inside then "(" ^ ty ^ ")" else ty + | TraitType (trait_ref, generics, type_name) -> + let trait_ref = trait_ref_to_string fmt false trait_ref in + let s = + if generics = empty_generic_args then trait_ref ^ "::" ^ type_name + else + let generics = generic_args_to_string fmt generics in + "(" ^ trait_ref ^ " " ^ generics ^ ")::" ^ type_name + in + if inside then "(" ^ s ^ ")" else s + +and generic_args_to_strings (fmt : type_formatter) (inside : bool) + (generics : generic_args) : string list = + let tys = List.map (ty_to_string fmt inside) generics.types in + let cgs = List.map (const_generic_to_string fmt) generics.const_generics in + let trait_refs = + List.map (trait_ref_to_string fmt inside) generics.trait_refs + in + List.concat [ tys; cgs; trait_refs ] + +and generic_args_to_string (fmt : type_formatter) (generics : generic_args) : + string = + String.concat " " (generic_args_to_strings fmt true generics) + +and trait_ref_to_string (fmt : type_formatter) (inside : bool) (tr : trait_ref) + : string = + let trait_id = trait_instance_id_to_string fmt false tr.trait_id in + let generics = generic_args_to_string fmt tr.generics in + let s = trait_id ^ generics in + if tr.generics = empty_generic_args || not inside then s else "(" ^ s ^ ")" + +and trait_instance_id_to_string (fmt : type_formatter) (inside : bool) + (id : trait_instance_id) : string = + match id with + | Self -> "Self" + | TraitImpl id -> fmt.trait_impl_id_to_string id + | Clause id -> fmt.trait_clause_id_to_string id + | ParentClause (inst_id, _decl_id, clause_id) -> + let inst_id = trait_instance_id_to_string fmt false inst_id in + let clause_id = fmt.trait_clause_id_to_string clause_id in + "parent(" ^ inst_id ^ ")::" ^ clause_id + | ItemClause (inst_id, _decl_id, item_name, clause_id) -> + let inst_id = trait_instance_id_to_string fmt false inst_id in + let clause_id = fmt.trait_clause_id_to_string clause_id in + "(" ^ inst_id ^ ")::" ^ item_name ^ "::[" ^ clause_id ^ "]" + | TraitRef tr -> trait_ref_to_string fmt inside tr + | UnknownTrait msg -> "UNKNOWN(" ^ msg ^ ")" + +let trait_clause_to_string (fmt : type_formatter) (clause : trait_clause) : + string = + let clause_id = fmt.trait_clause_id_to_string clause.clause_id in + let trait_id = fmt.trait_decl_id_to_string clause.trait_id in + let generics = generic_args_to_strings fmt true clause.generics in + let generics = + if generics = [] then "" else " " ^ String.concat " " generics + in + "[" ^ clause_id ^ "]: " ^ trait_id ^ generics + +let generic_params_to_strings (fmt : type_formatter) (generics : generic_params) + : string list = + let tys = List.map type_var_to_string generics.types in + let cgs = List.map const_generic_var_to_string generics.const_generics in + let trait_clauses = + List.map (trait_clause_to_string fmt) generics.trait_clauses + in + List.concat [ tys; cgs; trait_clauses ] let field_to_string fmt inside (f : field) : string = match f.field_name with @@ -217,11 +315,10 @@ let variant_to_string fmt (v : variant) : string = ^ ")" let type_decl_to_string (fmt : type_formatter) (def : type_decl) : string = - let types = def.type_params in let name = name_to_string def.name in let params = - if types = [] then "" - else " " ^ String.concat " " (List.map type_var_to_string types) + if def.generics = empty_generic_params then "" + else " " ^ String.concat " " (generic_params_to_strings fmt def.generics) in match def.kind with | Struct fields -> @@ -256,10 +353,6 @@ let rec mprojection_to_string (fmt : ast_formatter) (inside : string) | pe :: p' -> ( let s = mprojection_to_string fmt inside p' in match pe.pkind with - | E.ProjOption variant_id -> - assert (variant_id = T.option_some_id); - assert (pe.field_id = T.FieldId.zero); - "(" ^ s ^ "as Option::Some)." ^ T.FieldId.to_string pe.field_id | E.ProjTuple _ -> "(" ^ s ^ ")." ^ T.FieldId.to_string pe.field_id | E.ProjAdt (adt_id, opt_variant_id) -> ( let field_name = @@ -294,11 +387,9 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) | Assumed aty -> ( (* Assumed type *) match aty with - | State | Array | Slice | Str -> + | State | Array | Slice | Str | RawPtr _ -> (* Those types are opaque: we can't get there *) raise (Failure "Unreachable") - | Vec -> "@Vec" - | Range -> "@Range" | Result -> let variant_id = Option.get variant_id in if variant_id = result_return_id then "@Result::Return" @@ -314,13 +405,7 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) let variant_id = Option.get variant_id in if variant_id = fuel_zero_id then "@Fuel::Zero" else if variant_id = fuel_succ_id then "@Fuel::Succ" - else raise (Failure "Unreachable: improper variant id for fuel type") - | Option -> - let variant_id = Option.get variant_id in - if variant_id = option_some_id then "@Option::Some " - else if variant_id = option_none_id then "@Option::None" - else - raise (Failure "Unreachable: improper variant id for result type")) + else raise (Failure "Unreachable: improper variant id for fuel type")) let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) (field_id : FieldId.id) : string = @@ -337,11 +422,10 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) | Assumed aty -> ( (* Assumed type *) match aty with - | Range -> FieldId.to_string field_id - | State | Fuel | Vec | Array | Slice | Str -> + | State | Fuel | Array | Slice | Str -> (* Opaque types: we can't get there *) raise (Failure "Unreachable") - | Result | Error | Option -> + | Result | Error | RawPtr _ -> (* Enumerations: we can't get there *) raise (Failure "Unreachable")) @@ -353,10 +437,10 @@ let adt_g_value_to_string (fmt : value_formatter) (field_values : 'v list) (ty : ty) : string = let field_values = List.map value_to_string field_values in match ty with - | Adt (Tuple, _, _) -> + | Adt (Tuple, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | Adt (AdtId def_id, _, _) -> + | Adt (AdtId def_id, _) -> (* "Regular" ADT *) let adt_ident = match variant_id with @@ -378,10 +462,10 @@ let adt_g_value_to_string (fmt : value_formatter) let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | Adt (Assumed aty, _, _) -> ( + | Adt (Assumed aty, _) -> ( (* Assumed type *) match aty with - | State -> + | State | RawPtr _ -> (* This type is opaque: we can't get there *) raise (Failure "Unreachable") | Result -> @@ -412,31 +496,13 @@ let adt_g_value_to_string (fmt : value_formatter) | [ v ] -> "@Fuel::Succ " ^ v | _ -> raise (Failure "@Fuel::Succ takes exactly one value") else raise (Failure "Unreachable: improper variant id for fuel type") - | Option -> - let variant_id = Option.get variant_id in - if variant_id = option_some_id then - match field_values with - | [ v ] -> "@Option::Some " ^ v - | _ -> raise (Failure "Option::Some takes exactly one value") - else if variant_id = option_none_id then ( - assert (field_values = []); - "@Option::None") - else - raise (Failure "Unreachable: improper variant id for result type") - | Vec | Array | Slice | Str -> + | Array | Slice | Str -> assert (variant_id = None); let field_values = List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values in let id = assumed_ty_to_string aty in - id ^ " [" ^ String.concat "; " field_values ^ "]" - | Range -> - assert (variant_id = None); - let field_values = - List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values - in - let id = assumed_ty_to_string aty in - id ^ " {" ^ String.concat "; " field_values ^ "}") + id ^ " [" ^ String.concat "; " field_values ^ "]") | _ -> let fmt = value_to_type_formatter fmt in raise @@ -464,10 +530,10 @@ let rec typed_pattern_to_string (fmt : ast_formatter) (v : typed_pattern) : let fun_sig_to_string (fmt : ast_formatter) (sg : fun_sig) : string = let ty_fmt = ast_to_type_formatter fmt in - let type_params = List.map type_var_to_string sg.type_params in + let generics = generic_params_to_strings ty_fmt sg.generics in let inputs = List.map (ty_to_string ty_fmt false) sg.inputs in let output = ty_to_string ty_fmt false sg.output in - let all_types = List.concat [ type_params; inputs; [ output ] ] in + let all_types = List.concat [ generics; inputs; [ output ] ] in String.concat " -> " all_types let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string = @@ -495,28 +561,16 @@ let fun_suffix (lp_id : LoopId.id option) (rg_id : T.RegionGroupId.id option) : let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = match fid with - | A.Replace -> "core::mem::replace" - | A.BoxNew -> "alloc::boxed::Box::new" - | A.BoxDeref -> "core::ops::deref::Deref::deref" - | A.BoxDerefMut -> "core::ops::deref::DerefMut::deref_mut" - | A.BoxFree -> "alloc::alloc::box_free" - | A.VecNew -> "alloc::vec::Vec::new" - | A.VecPush -> "alloc::vec::Vec::push" - | A.VecInsert -> "alloc::vec::Vec::insert" - | A.VecLen -> "alloc::vec::Vec::len" - | A.VecIndex -> "core::ops::index::Index<alloc::vec::Vec>::index" - | A.VecIndexMut -> "core::ops::index::IndexMut<alloc::vec::Vec>::index_mut" + | BoxNew -> "alloc::boxed::Box::new" + | BoxFree -> "alloc::alloc::box_free" | ArrayIndexShared -> "@ArrayIndexShared" | ArrayIndexMut -> "@ArrayIndexMut" | ArrayToSliceShared -> "@ArrayToSliceShared" | ArrayToSliceMut -> "@ArrayToSliceMut" - | ArraySubsliceShared -> "@ArraySubsliceShared" - | ArraySubsliceMut -> "@ArraySubsliceMut" + | ArrayRepeat -> "@ArrayRepeat" | SliceLen -> "@SliceLen" | SliceIndexShared -> "@SliceIndexShared" | SliceIndexMut -> "@SliceIndexMut" - | SliceSubsliceShared -> "@SliceSubsliceShared" - | SliceSubsliceMut -> "@SliceSubsliceMut" let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string = match fid with @@ -531,8 +585,11 @@ let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string = | FromLlbc (fid, lp_id, rg_id) -> let f = match fid with - | Regular fid -> fmt.fun_decl_id_to_string fid - | Assumed fid -> llbc_assumed_fun_id_to_string fid + | FunId (Regular fid) -> fmt.fun_decl_id_to_string fid + | FunId (Assumed fid) -> llbc_assumed_fun_id_to_string fid + | TraitMethod (trait_ref, method_name, _) -> + let fmt = ast_to_type_formatter fmt in + trait_ref_to_string fmt true trait_ref ^ "." ^ method_name in f ^ fun_suffix lp_id rg_id | Pure fid -> pure_assumed_fun_id_to_string fid @@ -559,9 +616,8 @@ let fun_or_op_id_to_string (fmt : ast_formatter) (fun_id : fun_or_op_id) : let rec texpression_to_string (fmt : ast_formatter) (inside : bool) (indent : string) (indent_incr : string) (e : texpression) : string = match e.e with - | Var var_id -> - let s = fmt.var_id_to_string var_id in - if inside then "(" ^ s ^ ")" else s + | Var var_id -> fmt.var_id_to_string var_id + | CVar cg_id -> fmt.const_generic_var_id_to_string cg_id | Const cv -> literal_to_string cv | App _ -> (* Recursively destruct the app, to have a pair (app, arguments list) *) @@ -632,10 +688,11 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) (* There are two possibilities: either the [app] is an instantiated, * top-level qualifier (function, ADT constructore...), or it is a "regular" * expression *) - let app, tys = + let app, generics = match app.e with | Qualif qualif -> (* Qualifier case *) + let ty_fmt = ast_to_type_formatter fmt in (* Convert the qualifier identifier *) let qualif_s = match qualif.id with @@ -654,12 +711,17 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) let field_s = adt_field_to_string value_fmt adt_id field_id in (* Adopting an F*-like syntax *) ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s + | TraitConst (trait_ref, generics, const_name) -> + let trait_ref = trait_ref_to_string ty_fmt true trait_ref in + let generics_s = generic_args_to_string ty_fmt generics in + if generics <> empty_generic_args then + "(" ^ trait_ref ^ generics_s ^ ")." ^ const_name + else trait_ref ^ "." ^ const_name in (* Convert the type instantiation *) - let ty_fmt = ast_to_type_formatter fmt in - let tys = List.map (ty_to_string ty_fmt true) qualif.type_args in + let generics = generic_args_to_strings ty_fmt true qualif.generics in (* *) - (qualif_s, tys) + (qualif_s, generics) | _ -> (* "Regular" expression case *) let inside = args <> [] || (args = [] && inside) in @@ -674,7 +736,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) texpression_to_string fmt inside indent1 indent_incr in let args = List.map arg_to_string args in - let all_args = List.append tys args in + let all_args = List.append generics args in (* Put together *) let e = if all_args = [] then app else app ^ " " ^ String.concat " " all_args diff --git a/compiler/Pure.ml b/compiler/Pure.ml index ac4ca081..e6a3dab5 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -13,6 +13,9 @@ module FieldId = T.FieldId module SymbolicValueId = V.SymbolicValueId module FunDeclId = A.FunDeclId module GlobalDeclId = A.GlobalDeclId +module TraitDeclId = T.TraitDeclId +module TraitImplId = T.TraitImplId +module TraitClauseId = T.TraitClauseId (** We redefine identifiers for loop: in {!Values}, the identifiers are global (they monotonically increase across functions) while in {!module:Pure} we want @@ -21,8 +24,6 @@ module GlobalDeclId = A.GlobalDeclId module LoopId = IdGen () -type loop_id = LoopId.id [@@deriving show, ord] - (** We give an identifier to every phase of the synthesis (forward, backward for group of regions 0, etc.) *) module SynthPhaseId = @@ -37,6 +38,16 @@ module ConstGenericVarId = T.ConstGenericVarId type integer_type = T.integer_type [@@deriving show, ord] type const_generic_var = T.const_generic_var [@@deriving show, ord] type const_generic = T.const_generic [@@deriving show, ord] +type const_generic_var_id = T.const_generic_var_id [@@deriving show, ord] +type trait_decl_id = T.trait_decl_id [@@deriving show, ord] +type trait_impl_id = T.trait_impl_id [@@deriving show, ord] +type trait_clause_id = T.trait_clause_id [@@deriving show, ord] +type trait_item_name = T.trait_item_name [@@deriving show, ord] +type global_decl_id = T.global_decl_id [@@deriving show, ord] +type fun_decl_id = A.fun_decl_id [@@deriving show, ord] +type loop_id = LoopId.id [@@deriving show, ord] +type region_group_id = T.region_group_id [@@deriving show, ord] +type mutability = Mut | Const [@@deriving show, ord] (** The assumed types for the pure AST. @@ -59,12 +70,17 @@ type assumed_ty = | Result | Error | Fuel - | Vec - | Option | Array | Slice | Str - | Range + | RawPtr of mutability + (** The bool + Raw pointers don't make sense in the pure world, but we don't know + how to translate them yet and we have to handle some functions which + use raw pointers in their signature (for instance some trait declarations + for the slices). For now, we use a dedicated type to "mark" the raw pointers, + and make sure that those functions are actually not used in the translation. + *) [@@deriving show, ord] (* TODO: we should never directly manipulate [Return] and [Fail], but rather @@ -176,6 +192,14 @@ class ['self] iter_ty_base = inherit! [_] T.iter_const_generic inherit! [_] PV.iter_literal_type method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> () + method visit_trait_decl_id : 'env -> trait_decl_id -> unit = fun _ _ -> () + method visit_trait_impl_id : 'env -> trait_impl_id -> unit = fun _ _ -> () + + method visit_trait_clause_id : 'env -> trait_clause_id -> unit = + fun _ _ -> () + + method visit_trait_item_name : 'env -> trait_item_name -> unit = + fun _ _ -> () end (** Ancestor for map visitor for [ty] *) @@ -185,6 +209,18 @@ class ['self] map_ty_base = inherit! [_] T.map_const_generic inherit! [_] PV.map_literal_type method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x + + method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id = + fun _ x -> x + + method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id = + fun _ x -> x + + method visit_trait_clause_id : 'env -> trait_clause_id -> trait_clause_id = + fun _ x -> x + + method visit_trait_item_name : 'env -> trait_item_name -> trait_item_name = + fun _ x -> x end (** Ancestor for reduce visitor for [ty] *) @@ -194,6 +230,18 @@ class virtual ['self] reduce_ty_base = inherit! [_] T.reduce_const_generic inherit! [_] PV.reduce_literal_type method visit_type_var_id : 'env -> type_var_id -> 'a = fun _ _ -> self#zero + + method visit_trait_decl_id : 'env -> trait_decl_id -> 'a = + fun _ _ -> self#zero + + method visit_trait_impl_id : 'env -> trait_impl_id -> 'a = + fun _ _ -> self#zero + + method visit_trait_clause_id : 'env -> trait_clause_id -> 'a = + fun _ _ -> self#zero + + method visit_trait_item_name : 'env -> trait_item_name -> 'a = + fun _ _ -> self#zero end (** Ancestor for mapreduce visitor for [ty] *) @@ -205,10 +253,24 @@ class virtual ['self] mapreduce_ty_base = method visit_type_var_id : 'env -> type_var_id -> type_var_id * 'a = fun _ x -> (x, self#zero) + + method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id * 'a = + fun _ x -> (x, self#zero) + + method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id * 'a = + fun _ x -> (x, self#zero) + + method visit_trait_clause_id + : 'env -> trait_clause_id -> trait_clause_id * 'a = + fun _ x -> (x, self#zero) + + method visit_trait_item_name + : 'env -> trait_item_name -> trait_item_name * 'a = + fun _ x -> (x, self#zero) end type ty = - | Adt of type_id * ty list * const_generic list + | Adt of type_id * generic_args (** {!Adt} encodes ADTs and tuples and assumed types. TODO: what about the ended regions? (ADTs may be parameterized @@ -219,8 +281,38 @@ type ty = | TypeVar of type_var_id | Literal of literal_type | Arrow of ty * ty + | TraitType of trait_ref * generic_args * string + (** The string is for the name of the associated type *) + +and trait_ref = { + trait_id : trait_instance_id; + generics : generic_args; + trait_decl_ref : trait_decl_ref; +} + +and trait_decl_ref = { + trait_decl_id : trait_decl_id; + decl_generics : generic_args; (* The name: annoying field collisions... *) +} + +and generic_args = { + types : ty list; + const_generics : const_generic list; + trait_refs : trait_ref list; +} + +and trait_instance_id = + | Self + | TraitImpl of trait_impl_id + | Clause of trait_clause_id + | ParentClause of trait_instance_id * trait_decl_id * trait_clause_id + | ItemClause of + trait_instance_id * trait_decl_id * trait_item_name * trait_clause_id + | TraitRef of trait_ref + | UnknownTrait of string [@@deriving show, + ord, visitors { name = "iter_ty"; @@ -264,12 +356,37 @@ type type_decl_kind = Struct of field list | Enum of variant list | Opaque type type_var = T.type_var [@@deriving show] +type trait_clause = { + clause_id : trait_clause_id; + trait_id : trait_decl_id; + generics : generic_args; +} +[@@deriving show] + +type generic_params = { + types : type_var list; + const_generics : const_generic_var list; + trait_clauses : trait_clause list; +} +[@@deriving show] + +type trait_type_constraint = { + trait_ref : trait_ref; + generics : generic_args; + type_name : trait_item_name; + ty : ty; +} +[@@deriving show, ord] + +type predicates = { trait_type_constraints : trait_type_constraint list } +[@@deriving show] + type type_decl = { def_id : TypeDeclId.id; name : name; - type_params : type_var list; - const_generic_params : const_generic_var list; + generics : generic_params; kind : type_decl_kind; + preds : predicates; } [@@deriving show] @@ -420,8 +537,15 @@ type pure_assumed_fun_id = | FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *) [@@deriving show, ord] +type fun_id_or_trait_method_ref = + | FunId of A.fun_id + | TraitMethod of trait_ref * string * fun_decl_id + (** The fun decl id is not really needed and here for convenience purposes *) +[@@deriving show, ord] + (** A function id for a non-assumed function *) -type regular_fun_id = A.fun_id * LoopId.id option * T.RegionGroupId.id option +type regular_fun_id = + fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option [@@deriving show, ord] (** A function identifier *) @@ -457,23 +581,20 @@ type projection = { adt_id : type_id; field_id : FieldId.id } [@@deriving show] type qualif_id = | FunOrOp of fun_or_op_id (** A function or an operation *) - | Global of GlobalDeclId.id + | Global of global_decl_id | AdtCons of adt_cons_id (** A function or ADT constructor identifier *) | Proj of projection (** Field projector *) + | TraitConst of trait_ref * generic_args * string + (** A trait associated constant *) [@@deriving show] -(** An instantiated qualified. +(** An instantiated qualifier. Note that for now we have a clear separation between types and expressions, - which explains why we have the [type_params] field: a function or ADT + which explains why we have the [generics] field: a function or ADT constructor is always fully instantiated. *) -type qualif = { - id : qualif_id; - type_args : ty list; - const_generic_args : const_generic list; -} -[@@deriving show] +type qualif = { id : qualif_id; generics : generic_args } [@@deriving show] type field_id = FieldId.id [@@deriving show, ord] type var_id = VarId.id [@@deriving show, ord] @@ -536,6 +657,7 @@ class virtual ['self] mapreduce_expression_base = *) type expression = | Var of var_id (** a variable *) + | CVar of const_generic_var_id (** a const generic var *) | Const of literal | App of texpression * texpression (** Application of a function to an argument. @@ -787,11 +909,11 @@ type fun_sig_info = { - etc. *) type fun_sig = { - type_params : type_var list; - const_generic_params : const_generic_var list; + generics : generic_params; (** TODO: we should analyse the signature to make the type parameters implicit whenever possible *) + preds : predicates; inputs : ty list; - (** The input types. + (** The types of the inputs. Note that those input types take into account the [fuel] parameter, if the function uses fuel for termination, and the [state] parameter, @@ -861,8 +983,11 @@ type fun_body = { } [@@deriving show] +type fun_kind = A.fun_kind [@@deriving show] + type fun_decl = { def_id : FunDeclId.id; + kind : fun_kind; num_loops : int; (** The number of loops in the parent forward function (basically the number of loops appearing in the original Rust functions, unless some loops are @@ -882,3 +1007,30 @@ type fun_decl = { body : fun_body option; } [@@deriving show] + +type trait_decl = { + def_id : trait_decl_id; + name : name; + generics : generic_params; + preds : predicates; + parent_clauses : trait_clause list; + consts : (trait_item_name * (ty * global_decl_id option)) list; + types : (trait_item_name * (trait_clause list * ty option)) list; + required_methods : (trait_item_name * fun_decl_id) list; + provided_methods : (trait_item_name * fun_decl_id option) list; +} +[@@deriving show] + +type trait_impl = { + def_id : trait_impl_id; + name : name; + impl_trait : trait_decl_ref; + generics : generic_params; + preds : predicates; + parent_trait_refs : trait_ref list; + consts : (trait_item_name * (ty * global_decl_id)) list; + types : (trait_item_name * (trait_ref list * ty)) list; + required_methods : (trait_item_name * fun_decl_id) list; + provided_methods : (trait_item_name * fun_decl_id) list; +} +[@@deriving show] diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index b6025df4..f3e6cbe2 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -376,8 +376,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ty = e.ty in let ctx, e = match e.e with - | Var _ -> (* Nothing to do *) (ctx, e.e) - | Const _ -> (* Nothing to do *) (ctx, e.e) + | Var _ | CVar _ | Const _ -> (* Nothing to do *) (ctx, e.e) | App (app, arg) -> let ctx, app = update_texpression app ctx in let ctx, arg = update_texpression arg ctx in @@ -584,13 +583,10 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = | Qualif { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; - type_args = _; - const_generic_args = _; + generics = _; } -> (* Lookup the def *) - let decl = - TypeDeclId.Map.find adt_id ctx.type_context.type_decls - in + let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in (* Check that there are as many arguments as there are fields - note that the def should have a body (otherwise we couldn't use the constructor) *) @@ -599,8 +595,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* Check if the definition is recursive *) let is_rec = match - TypeDeclId.Map.find adt_id - ctx.type_context.type_decls_groups + TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups with | NonRec _ -> false | Rec _ -> true @@ -682,8 +677,8 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) | _ -> false in (* And either: - * 2.1 the right-expression is a variable or a global *) - let var_or_global = is_var re || is_global re in + * 2.1 the right-expression is a variable, a global or a const generic var *) + let var_or_global = is_var re || is_cvar re || is_global re in (* Or: * 2.2 the right-expression is a constant value, an ADT value, * a projection or a primitive function call *and* the flag @@ -767,10 +762,10 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) In this situation, we can remove the call [f@fwd x]. *) let expression_contains_child_call_in_all_paths (ctx : trans_ctx) - (id0 : A.fun_id) (lp_id0 : LoopId.id option) - (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list) + (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option) + (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args) (args0 : texpression list) (e : texpression) : bool = - let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list) + let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args) (args1 : texpression list) : bool = (* Check the fun_ids, to see if call1's function is a child of call0's function *) match fun_id1 with @@ -793,7 +788,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (* We need to use the regions hierarchy *) (* First, lookup the signature of the LLBC function *) let sg = - LlbcAstUtils.lookup_fun_sig id0 ctx.fun_context.fun_decls + let id0 = + match id0 with + | FunId fun_id -> fun_id + | TraitMethod (_, _, fun_decl_id) -> Regular fun_decl_id + in + LlbcAstUtils.lookup_fun_sig id0 ctx.fun_ctx.fun_decls in (* Compute the set of ancestors of the function in call1 *) let call1_ancestors = @@ -817,8 +817,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) let input_eq (v0, v1) = PureUtils.remove_meta v0 = PureUtils.remove_meta v1 in - (* Compare the input types and the prefix of the input arguments *) - tys0 = tys1 && List.for_all input_eq args + (* Compare the generics and the prefix of the input arguments *) + generics0 = generics1 && List.for_all input_eq args else (* Not a child *) false else (* Not the same function *) @@ -834,7 +834,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) method! visit_texpression env e = match e.e with - | Var _ | Const _ -> fun _ -> false + | Var _ | CVar _ | Const _ -> fun _ -> false | StructUpdate _ -> (* There shouldn't be monadic calls in structure updates - also note that by returning [false] we are conservative: we might @@ -844,8 +844,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) | Let (_, _, re, e) -> ( match opt_destruct_function_call re with | None -> fun () -> self#visit_texpression env e () - | Some (func1, tys1, args1) -> - let call_is_child = check_call func1 tys1 args1 in + | Some (func1, generics1, args1) -> + let call_is_child = check_call func1 generics1 args1 in if call_is_child then fun () -> true else fun () -> self#visit_texpression env e ()) | App _ -> ( @@ -930,7 +930,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) method! visit_expression env e = match e with - | Var _ | Const _ | App _ | Qualif _ + | Var _ | CVar _ | Const _ | App _ | Qualif _ | Switch (_, _) | Meta (_, _) | StructUpdate _ | Abs _ -> @@ -1086,13 +1086,12 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = | Qualif { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; - type_args; - const_generic_args; + generics; } -> (* This is a struct *) (* Retrieve the definiton, to find how many fields there are *) let adt_decl = - TypeDeclId.Map.find adt_id ctx.type_context.type_decls + TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in let fields = match adt_decl.kind with @@ -1108,7 +1107,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = * [x.field] for some variable [x], and where the projection * is for the proper ADT *) let to_var_proj (i : int) (arg : texpression) : - (ty list * const_generic list * var_id) option = + (generic_args * var_id) option = match arg.e with | App (proj, x) -> ( match (proj.e, x.e) with @@ -1116,16 +1115,14 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = { id = Proj { adt_id = AdtId proj_adt_id; field_id }; - type_args = proj_type_args; - const_generic_args = proj_const_generic_args; + generics = proj_generics; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) if proj_adt_id = adt_id && FieldId.to_int field_id = i - then - Some (proj_type_args, proj_const_generic_args, v) + then Some (proj_generics, v) else None | _ -> None) | _ -> None @@ -1136,14 +1133,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = if List.length args = num_fields then (* Check that this is the same variable we project from - * note that we checked above that there is at least one field *) - let (_, _, x), end_args = Collections.List.pop args in - if List.for_all (fun (_, _, y) -> y = x) end_args then ( + let (_, x), end_args = Collections.List.pop args in + if List.for_all (fun (_, y) -> y = x) end_args then ( (* We can substitute *) (* Sanity check: all types correct *) assert ( List.for_all - (fun (tys, cgs, _) -> - tys = type_args && cgs = const_generic_args) + (fun (generics1, _) -> generics1 = generics) args); { e with e = Var x }) else super#visit_texpression env e @@ -1162,8 +1158,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = | ( Qualif { id = Proj { adt_id = AdtId proj_adt_id; field_id }; - type_args = _; - const_generic_args = _; + generics = _; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) @@ -1361,8 +1356,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_sig = { - type_params = fun_sig.type_params; - const_generic_params = fun_sig.const_generic_params; + generics = fun_sig.generics; + preds = fun_sig.preds; inputs = inputs_tys; output; doutputs; @@ -1427,6 +1422,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_def = { def_id = def.def_id; + kind = def.kind; num_loops; loop_id = Some loop.loop_id; back_id = def.back_id; @@ -1466,13 +1462,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = In such situation, we can remove the forward function definition altogether. *) -let keep_forward (trans : pure_fun_translation) : bool = - let (fwd, _), backs = trans in +let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = (* Note that at this point, the output types are no longer seen as tuples: * they should be lists of length 1. *) if !Config.filter_useless_functions - && fwd.signature.output = mk_result_ty mk_unit_ty + && fwd.f.signature.output = mk_result_ty mk_unit_ty && backs <> [] then false else true @@ -1518,7 +1513,7 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl = function calls, and when translating end abstractions. Here, we can do something simpler, in one micro-pass. *) -let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = +let eliminate_box_functions (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* The map visitor *) let obj = object @@ -1527,30 +1522,44 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = method! visit_texpression env e = match opt_destruct_function_call e with | Some (fun_id, _tys, args) -> ( + (* Below, when dealing with the arguments: we consider the very + * general case, where functions could be boxed (meaning we + * could have: [box_new f x]) + * *) match fun_id with - | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> ( - (* Below, when dealing with the arguments: we consider the very - * general case, where functions could be boxed (meaning we - * could have: [box_new f x]) - * *) + | Fun (FromLlbc (FunId (Assumed aid), _lp_id, rg_id)) -> ( match (aid, rg_id) with - | A.BoxNew, _ -> + | BoxNew, _ -> assert (rg_id = None); let arg, args = Collections.List.pop args in mk_apps arg args - | A.BoxDeref, None -> + | BoxFree, _ -> + assert (args = []); + mk_unit_rvalue + | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared + | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut + | ArrayRepeat | SliceLen ), + _ ) -> + super#visit_texpression env e) + | Fun (FromLlbc (FunId (Regular fid), _lp_id, rg_id)) -> ( + (* Lookup the function name *) + let def = FunDeclId.Map.find fid ctx.fun_ctx.fun_decls in + match + (Names.name_no_disambiguators_to_string def.name, rg_id) + with + | "alloc::boxed::Box::deref", None -> (* [Box::deref] forward is the identity *) let arg, args = Collections.List.pop args in mk_apps arg args - | A.BoxDeref, Some _ -> + | "alloc::boxed::Box::deref", Some _ -> (* [Box::deref] backward is [()] (doesn't give back anything) *) assert (args = []); mk_unit_rvalue - | A.BoxDerefMut, None -> + | "alloc::boxed::Box::deref_mut", None -> (* [Box::deref_mut] forward is the identity *) let arg, args = Collections.List.pop args in mk_apps arg args - | A.BoxDerefMut, Some _ -> + | "alloc::boxed::Box::deref_mut", Some _ -> (* [Box::deref_mut] back is almost the identity: * let box_deref_mut (x_init : t) (x_back : t) : t = x_back * *) @@ -1560,17 +1569,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = | _ -> raise (Failure "Unreachable") in mk_apps arg args - | A.BoxFree, _ -> - assert (args = []); - mk_unit_rvalue - | ( ( A.Replace | VecNew | VecPush | VecInsert | VecLen - | VecIndex | VecIndexMut | ArraySubsliceShared - | ArraySubsliceMut | SliceIndexShared | SliceIndexMut - | SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared - | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut - | SliceLen ), - _ ) -> - super#visit_texpression env e) + | _ -> super#visit_texpression env e) | _ -> super#visit_texpression env e) | _ -> super#visit_texpression env e end @@ -1914,7 +1913,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = [ctx]: used only for printing. *) let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : - (fun_decl * fun_decl list) option = + fun_and_loops option = (* Debug *) log#ldebug (lazy @@ -1955,9 +1954,9 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : let def, loops = decompose_loops def in (* Apply the remaining passes *) - let def = apply_end_passes_to_def ctx def in + let f = apply_end_passes_to_def ctx def in let loops = List.map (apply_end_passes_to_def ctx) loops in - Some (def, loops) + Some { f; loops } (** Small utility for {!filter_loop_inputs} *) let filter_prefix (keep : bool list) (ls : 'a list) : 'a list = @@ -1983,8 +1982,8 @@ end module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) (** Filter the useless loop input parameters. *) -let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : - (bool * pure_fun_translation) list = +let filter_loop_inputs (transl : pure_fun_translation list) : + pure_fun_translation list = (* We need to explore groups of mutually recursive functions. In order to compute which parameters are useless, we need to explore the functions by groups of mutually recursive definitions. @@ -2002,10 +2001,11 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : (List.concat (List.concat (List.map - (fun (_, ((fwd, loops_fwd), backs)) -> - [ fwd :: loops_fwd ] + (fun { fwd; backs; _ } -> + [ fwd.f :: fwd.loops ] :: List.map - (fun (back, loops_back) -> [ back :: loops_back ]) + (fun { f = back; loops = loops_back } -> + [ back :: loops_back ]) backs) transl))) in @@ -2030,7 +2030,6 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : additional parameters. *) let used_map = ref FunLoopIdMap.empty in - let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in (* We start by computing the filtering information, for each function *) let compute_one_filter_info (decl : fun_decl) = @@ -2051,7 +2050,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : let inputs_set = VarId.Set.of_list (List.map var_get_id inputs_prefix) in assert (Option.is_some decl.loop_id); - let fun_id = (A.Regular decl.def_id, decl.loop_id) in + let fun_id = (E.Regular decl.def_id, decl.loop_id) in let set_used vid = used := List.map (fun (vid', b) -> (vid', b || vid = vid')) !used @@ -2075,8 +2074,8 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc fun_id')) -> - if fun_id_to_fun_loop_id fun_id' = fun_id then ( + | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) -> + if (fun_id', loop_id') = fun_id then ( (* For each argument, check if it is exactly the original input parameter. Note that there shouldn't be partial applications of loop functions: the number of arguments @@ -2135,22 +2134,15 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : (* We then apply the filtering to all the function definitions at once *) let filter_in_one (decl : fun_decl) : fun_decl = (* Filter the function signature *) - let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in + let fun_id = (E.Regular decl.def_id, decl.loop_id) in let decl = - match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with + match FunLoopIdMap.find_opt fun_id !used_map with | None -> (* Nothing to filter *) decl | Some used_info -> let num_filtered = List.length (List.filter (fun b -> not b) used_info) in - let { - type_params; - const_generic_params; - inputs; - output; - doutputs; - info; - } = + let { generics; preds; inputs; output; doutputs; info } = decl.signature in let { @@ -2178,16 +2170,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : effect_info; } in - let signature = - { - type_params; - const_generic_params; - inputs; - output; - doutputs; - info; - } - in + let signature = { generics; preds; inputs; output; doutputs; info } in { decl with signature } in @@ -2201,9 +2184,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : let { inputs; inputs_lvs; body } = body in let inputs, inputs_lvs = - match - FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map - with + match FunLoopIdMap.find_opt fun_id !used_map with | None -> (* Nothing to filter *) (inputs, inputs_lvs) | Some used_info -> let inputs = filter_prefix used_info inputs in @@ -2223,11 +2204,10 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc fun_id)) -> ( + | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _))) + -> ( match - FunLoopIdMap.find_opt - (fun_id_to_fun_loop_id fun_id) - !used_map + FunLoopIdMap.find_opt (fun_id, loop_id) !used_map with | None -> super#visit_texpression env e | Some used_info -> @@ -2267,13 +2247,13 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : in let transl = List.map - (fun (b, (fwd, backs)) -> - let filter_fun_and_loops (f, fl) = - (filter_in_one f, List.map filter_in_one fl) + (fun trans -> + let filter_fun_and_loops f = + { f = filter_in_one f.f; loops = List.map filter_in_one f.loops } in - let fwd = filter_fun_and_loops fwd in - let backs = List.map filter_fun_and_loops backs in - (b, (fwd, backs))) + let fwd = filter_fun_and_loops trans.fwd in + let backs = List.map filter_fun_and_loops trans.backs in + { trans with fwd; backs }) transl in @@ -2294,18 +2274,17 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : but convenient. *) let apply_passes_to_pure_fun_translations (ctx : trans_ctx) - (transl : (fun_decl * fun_decl list) list) : - (bool * pure_fun_translation) list = - let apply_to_one (trans : fun_decl * fun_decl list) : - bool * pure_fun_translation = + (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list = + let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation = (* Apply the passes to the individual functions *) - let forward, backwards = trans in - let forward = Option.get (apply_passes_to_def ctx forward) in - let backwards = List.filter_map (apply_passes_to_def ctx) backwards in - let trans = (forward, backwards) in + let fwd, backs = trans in + let fwd = Option.get (apply_passes_to_def ctx fwd) in + let backs = List.filter_map (apply_passes_to_def ctx) backs in (* Compute whether we need to filter the forward function or not *) - (keep_forward trans, trans) + let keep_fwd = keep_forward fwd backs in + { keep_fwd; fwd; backs } in + let transl = List.map apply_to_one transl in (* Filter the useless inputs in the loop functions *) diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 8d28bb8a..2ad942bb 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -9,17 +9,19 @@ open PureUtils of fields is fixed: it shouldn't be used for arrays, slices, etc. *) let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) - (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) - (cgs : const_generic list) : ty list = + (type_id : type_id) (variant_id : VariantId.id option) + (generics : generic_args) : ty list = match type_id with | Tuple -> (* Tuple *) + assert (generics.const_generics = []); + assert (generics.trait_refs = []); assert (variant_id = None); - tys + generics.types | AdtId def_id -> (* "Regular" ADT *) let def = TypeDeclId.Map.find def_id type_decls in - type_decl_get_instantiated_fields_types def variant_id tys cgs + type_decl_get_instantiated_fields_types def variant_id generics | Assumed aty -> ( (* Assumed type *) match aty with @@ -27,14 +29,14 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) (* This type is opaque *) raise (Failure "Unreachable: opaque type") | Result -> - let ty = Collections.List.to_cons_nil tys in + let ty = Collections.List.to_cons_nil generics.types in let variant_id = Option.get variant_id in if variant_id = result_return_id then [ ty ] else if variant_id = result_fail_id then [ mk_error_ty ] else raise (Failure "Unreachable: improper variant id for result type") | Error -> - assert (tys = []); + assert (generics = empty_generic_args); let variant_id = Option.get variant_id in assert ( variant_id = error_failure_id || variant_id = error_out_of_fuel_id); @@ -44,18 +46,7 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) if variant_id = fuel_zero_id then [] else if variant_id = fuel_succ_id then [ mk_fuel_ty ] else raise (Failure "Unreachable: improper variant id for fuel type") - | Option -> - let ty = Collections.List.to_cons_nil tys in - let variant_id = Option.get variant_id in - if variant_id = option_some_id then [ ty ] - else if variant_id = option_none_id then [] - else - raise (Failure "Unreachable: improper variant id for option type") - | Range -> - let ty = Collections.List.to_cons_nil tys in - assert (variant_id = None); - [ ty; ty ] - | Vec | Array | Slice | Str -> + | Array | Slice | Str | RawPtr _ -> (* Array: when not symbolic values (for instance, because of aggregates), the array expressions are introduced as struct updates *) raise (Failure "Attempting to access the fields of an opaque type")) @@ -65,6 +56,9 @@ type tc_ctx = { global_decls : A.global_decl A.GlobalDeclId.Map.t; (** The global declarations *) env : ty VarId.Map.t; (** Environment from variables to types *) + const_generics : ty T.ConstGenericVarId.Map.t; + (** The types of the const generics *) + (* TODO: add trait type constraints *) } let check_literal (v : literal) (ty : literal_type) : unit = @@ -86,12 +80,13 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx = { ctx with env } | PatAdt av -> (* Compute the field types *) - let type_id, tys, cgs = ty_as_adt v.ty in + let type_id, generics = ty_as_adt v.ty in let field_tys = - get_adt_field_types ctx.type_decls type_id av.variant_id tys cgs + get_adt_field_types ctx.type_decls type_id av.variant_id generics in let check_value (ctx : tc_ctx) (ty : ty) (v : typed_pattern) : tc_ctx = if ty <> v.ty then ( + (* TODO: we need to normalize the types *) log#serror ("check_typed_pattern: not the same types:" ^ "\n- ty: " ^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty); @@ -115,6 +110,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = match VarId.Map.find_opt var_id ctx.env with | None -> () | Some ty -> assert (ty = e.ty)) + | CVar cg_id -> + let ty = T.ConstGenericVarId.Map.find cg_id ctx.const_generics in + assert (ty = e.ty) | Const cv -> check_literal cv (ty_as_literal e.ty) | App (app, arg) -> let input_ty, output_ty = destruct_arrow app.ty in @@ -133,35 +131,34 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = match qualif.id with | FunOrOp _ -> () (* TODO *) | Global _ -> () (* TODO *) + | TraitConst _ -> () (* TODO *) | Proj { adt_id = proj_adt_id; field_id } -> (* Note we can only project fields of structures (not enumerations) *) (* Deconstruct the projector type *) let adt_ty, field_ty = destruct_arrow e.ty in - let adt_id, adt_type_args, adt_cg_args = ty_as_adt adt_ty in + let adt_id, adt_generics = ty_as_adt adt_ty in (* Check the ADT type *) assert (adt_id = proj_adt_id); - assert (adt_type_args = qualif.type_args); - assert (adt_cg_args = qualif.const_generic_args); + assert (adt_generics = qualif.generics); (* Retrieve and check the expected field type *) let variant_id = None in let expected_field_tys = get_adt_field_types ctx.type_decls proj_adt_id variant_id - qualif.type_args qualif.const_generic_args + qualif.generics in let expected_field_ty = FieldId.nth expected_field_tys field_id in assert (expected_field_ty = field_ty) | AdtCons id -> ( let expected_field_tys = get_adt_field_types ctx.type_decls id.adt_id id.variant_id - qualif.type_args qualif.const_generic_args + qualif.generics in let field_tys, adt_ty = destruct_arrows e.ty in assert (expected_field_tys = field_tys); match adt_ty with - | Adt (type_id, tys, cgs) -> + | Adt (type_id, generics) -> assert (type_id = id.adt_id); - assert (tys = qualif.type_args); - assert (cgs = qualif.const_generic_args) + assert (generics = qualif.generics) | _ -> raise (Failure "Unreachable"))) | Let (monadic, pat, re, e_next) -> let expected_pat_ty = if monadic then destruct_result re.ty else re.ty in @@ -207,15 +204,14 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = | Some ty -> assert (ty = e.ty)); (* Check the fields *) (* Retrieve and check the expected field type *) - let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in + let adt_id, adt_generics = ty_as_adt e.ty in assert (adt_id = supd.struct_id); (* The id can only be: a custom type decl or an array *) match adt_id with | AdtId _ -> let variant_id = None in let expected_field_tys = - get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args - adt_cg_args + get_adt_field_types ctx.type_decls adt_id variant_id adt_generics in List.iter (fun (fid, fe) -> @@ -224,7 +220,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx fe) supd.updates | Assumed Array -> - let expected_field_ty = Collections.List.to_cons_nil adt_type_args in + let expected_field_ty = + Collections.List.to_cons_nil adt_generics.types + in List.iter (fun (_, fe) -> assert (expected_field_ty = fe.ty); diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 1c8d8921..3aeabffe 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -89,14 +89,31 @@ let mk_mplace (var_id : E.VarId.id) (name : string option) (projection : mprojection) : mplace = { var_id; name; projection } +let empty_generic_params : generic_params = + { types = []; const_generics = []; trait_clauses = [] } + +let empty_generic_args : generic_args = + { types = []; const_generics = []; trait_refs = [] } + +let mk_generic_args_from_types (types : ty list) : generic_args = + { types; const_generics = []; trait_refs = [] } + +type subst = { + ty_subst : TypeVarId.id -> ty; + cg_subst : ConstGenericVarId.id -> const_generic; + tr_subst : TraitClauseId.id -> trait_instance_id; + tr_self : trait_instance_id; +} + (** Type substitution *) -let ty_substitute (tsubst : TypeVarId.id -> ty) - (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty = +let ty_substitute (subst : subst) (ty : ty) : ty = let obj = object inherit [_] map_ty - method! visit_TypeVar _ var_id = tsubst var_id - method! visit_ConstGenericVar _ var_id = cgsubst var_id + method! visit_TypeVar _ var_id = subst.ty_subst var_id + method! visit_ConstGenericVar _ var_id = subst.cg_subst var_id + method! visit_Clause _ id = subst.tr_subst id + method! visit_Self _ = subst.tr_self end in obj#visit_ty () ty @@ -115,6 +132,18 @@ let make_const_generic_subst (vars : const_generic_var list) (cgs : const_generic list) : ConstGenericVarId.id -> const_generic = Substitute.make_const_generic_subst_from_vars vars cgs +let make_trait_subst (clauses : trait_clause list) (refs : trait_ref list) : + TraitClauseId.id -> trait_instance_id = + let clauses = List.map (fun x -> x.clause_id) clauses in + let refs = List.map (fun x -> TraitRef x) refs in + let ls = List.combine clauses refs in + let mp = + List.fold_left + (fun mp (k, v) -> TraitClauseId.Map.add k v mp) + TraitClauseId.Map.empty ls + in + fun id -> TraitClauseId.Map.find id mp + (** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}. Raises [Invalid_argument] if the arguments are incorrect. @@ -135,20 +164,27 @@ let type_decl_get_fields (def : type_decl) - def: " ^ show_type_decl def ^ "\n- opt_variant_id: " ^ opt_variant_id)) +let make_subst_from_generics (params : generic_params) (args : generic_args) + (tr_self : trait_instance_id) : subst = + let ty_subst = make_type_subst params.types args.types in + let cg_subst = + make_const_generic_subst params.const_generics args.const_generics + in + let tr_subst = make_trait_subst params.trait_clauses args.trait_refs in + { ty_subst; cg_subst; tr_subst; tr_self } + (** Instantiate the type variables for the chosen variant in an ADT definition, and return the list of the types of its fields *) let type_decl_get_instantiated_fields_types (def : type_decl) - (opt_variant_id : VariantId.id option) (types : ty list) - (cgs : const_generic list) : ty list = - let ty_subst = make_type_subst def.type_params types in - let cg_subst = make_const_generic_subst def.const_generic_params cgs in + (opt_variant_id : VariantId.id option) (generics : generic_args) : ty list = + (* There shouldn't be any reference to Self *) + let tr_self = UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics def.generics generics tr_self in let fields = type_decl_get_fields def opt_variant_id in - List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields + List.map (fun f -> ty_substitute subst f.field_ty) fields -let fun_sig_substitute (tsubst : TypeVarId.id -> ty) - (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) : - inst_fun_sig = - let subst = ty_substitute tsubst cgsubst in +let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = + let subst = ty_substitute subst in let inputs = List.map subst sg.inputs in let output = subst sg.output in let doutputs = List.map subst sg.doutputs in @@ -164,7 +200,8 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Var _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> false + | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> + false | Let (monadic, _, _, next_e) -> if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false @@ -184,15 +221,18 @@ let is_var (e : texpression) : bool = let as_var (e : texpression) : VarId.id = match e.e with Var v -> v | _ -> raise (Failure "Unreachable") +let is_cvar (e : texpression) : bool = + match e.e with CVar _ -> true | _ -> false + let is_global (e : texpression) : bool = match e.e with Qualif { id = Global _; _ } -> true | _ -> false let is_const (e : texpression) : bool = match e.e with Const _ -> true | _ -> false -let ty_as_adt (ty : ty) : type_id * ty list * const_generic list = +let ty_as_adt (ty : ty) : type_id * generic_args = match ty with - | Adt (id, tys, cgs) -> (id, tys, cgs) + | Adt (id, generics) -> (id, generics) | _ -> raise (Failure "Unreachable") (** Remove the external occurrences of {!Meta} *) @@ -290,28 +330,30 @@ let destruct_qualif_app (e : texpression) : qualif * texpression list = (** Destruct an expression into a function call, if possible *) let opt_destruct_function_call (e : texpression) : - (fun_or_op_id * ty list * texpression list) option = + (fun_or_op_id * generic_args * texpression list) option = match opt_destruct_qualif_app e with | None -> None | Some (qualif, args) -> ( match qualif.id with - | FunOrOp fun_id -> Some (fun_id, qualif.type_args, args) + | FunOrOp fun_id -> Some (fun_id, qualif.generics, args) | _ -> None) let opt_destruct_result (ty : ty) : ty option = match ty with - | Adt (Assumed Result, tys, cgs) -> - assert (cgs = []); - Some (Collections.List.to_cons_nil tys) + | Adt (Assumed Result, generics) -> + assert (generics.const_generics = []); + assert (generics.trait_refs = []); + Some (Collections.List.to_cons_nil generics.types) | _ -> None let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty) let opt_destruct_tuple (ty : ty) : ty list option = match ty with - | Adt (Tuple, tys, cgs) -> - assert (cgs = []); - Some tys + | Adt (Tuple, generics) -> + assert (generics.const_generics = []); + assert (generics.trait_refs = []); + Some generics.types | _ -> None let mk_abs (x : typed_pattern) (e : texpression) : texpression = @@ -383,14 +425,16 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = - if there is > one type: wrap them in a tuple *) let mk_simpl_tuple_ty (tys : ty list) : ty = - match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, []) + match tys with + | [ ty ] -> ty + | _ -> Adt (Tuple, mk_generic_args_from_types tys) let mk_bool_ty : ty = Literal Bool -let mk_unit_ty : ty = Adt (Tuple, [], []) +let mk_unit_ty : ty = Adt (Tuple, empty_generic_args) let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = []; const_generic_args = [] } in + let qualif = { id; generics = empty_generic_args } in let e = Qualif qualif in let ty = mk_unit_ty in { e; ty } @@ -430,7 +474,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern = | [ v ] -> v | _ -> let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in - let ty = Adt (Tuple, tys, []) in + let ty = Adt (Tuple, mk_generic_args_from_types tys) in let value = PatAdt { variant_id = None; field_values = vl } in { value; ty } @@ -441,11 +485,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression = | _ -> (* Compute the types of the fields, and the type of the tuple constructor *) let tys = List.map (fun (v : texpression) -> v.ty) vl in - let ty = Adt (Tuple, tys, []) in + let ty = Adt (Tuple, mk_generic_args_from_types tys) in let ty = mk_arrows tys ty in (* Construct the tuple constructor qualifier *) let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = tys; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types tys } in (* Put everything together *) let cons = { e = Qualif qualif; ty } in mk_apps cons vl @@ -463,32 +507,36 @@ let ty_as_integer (t : ty) : T.integer_type = let ty_as_literal (t : ty) : T.literal_type = match t with Literal ty -> ty | _ -> raise (Failure "Unreachable") -let mk_state_ty : ty = Adt (Assumed State, [], []) -let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], []) -let mk_error_ty : ty = Adt (Assumed Error, [], []) -let mk_fuel_ty : ty = Adt (Assumed Fuel, [], []) +let mk_state_ty : ty = Adt (Assumed State, empty_generic_args) + +let mk_result_ty (ty : ty) : ty = + Adt (Assumed Result, mk_generic_args_from_types [ ty ]) + +let mk_error_ty : ty = Adt (Assumed Error, empty_generic_args) +let mk_fuel_ty : ty = Adt (Assumed Fuel, empty_generic_args) let mk_error (error : VariantId.id) : texpression = let ty = mk_error_ty in let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in - let qualif = { id; type_args = []; const_generic_args = [] } in + let qualif = { id; generics = empty_generic_args } in let e = Qualif qualif in { e; ty } let unwrap_result_ty (ty : ty) : ty = match ty with - | Adt (Assumed Result, [ ty ], cgs) -> - assert (cgs = []); + | Adt + (Assumed Result, { types = [ ty ]; const_generics = []; trait_refs = [] }) + -> ty | _ -> raise (Failure "not a result type") let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in - let ty = Adt (Assumed Result, type_args, []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id } in - let qualif = { id; type_args; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types type_args } in let cons_e = Qualif qualif in let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -501,11 +549,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) : let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in - let ty = Adt (Assumed Result, type_args, []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id } in - let qualif = { id; type_args; const_generic_args = [] } in + let qualif = { id; generics = mk_generic_args_from_types type_args } in let cons_e = Qualif qualif in let cons_ty = mk_arrow v.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -514,7 +562,7 @@ let mk_result_return_texpression (v : texpression) : texpression = (** Create a [Fail err] pattern which captures the error *) let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern = let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in - let ty = Adt (Assumed Result, [ ty ], []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types [ ty ]) in let value = PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] } in @@ -526,7 +574,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = mk_result_fail_pattern error_pat ty let mk_result_return_pattern (v : typed_pattern) : typed_pattern = - let ty = Adt (Assumed Result, [ v.ty ], []) in + let ty = Adt (Assumed Result, mk_generic_args_from_types [ v.ty ]) in let value = PatAdt { variant_id = Some result_return_id; field_values = [ v ] } in @@ -561,11 +609,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option let fields_values = List.map (fun e -> Option.get e) fields in (* Retrieve the type id and the type args from the pat type (simpler this way *) - let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in + let adt_id, generics = ty_as_adt pat.ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args; const_generic_args } in + let qualif = { id = qualif_id; generics } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) fields_values @@ -577,3 +625,55 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option Some (mk_apps cons fields_values).e in match e_opt with None -> None | Some e -> Some { e; ty = pat.ty } + +type trait_decl_method_decl_id = { is_provided : bool; id : fun_decl_id } + +let trait_decl_get_method (trait_decl : trait_decl) (method_name : string) : + trait_decl_method_decl_id = + (* First look in the required methods *) + let method_id = + List.find_opt (fun (s, _) -> s = method_name) trait_decl.required_methods + in + match method_id with + | Some (_, id) -> { is_provided = false; id } + | None -> + (* Must be a provided method *) + let _, id = + List.find (fun (s, _) -> s = method_name) trait_decl.provided_methods + in + { is_provided = true; id = Option.get id } + +let trait_decl_is_empty (trait_decl : trait_decl) : bool = + let { + def_id = _; + name = _; + generics = _; + preds = _; + parent_clauses; + consts; + types; + required_methods; + provided_methods; + } = + trait_decl + in + parent_clauses = [] && consts = [] && types = [] && required_methods = [] + && provided_methods = [] + +let trait_impl_is_empty (trait_impl : trait_impl) : bool = + let { + def_id = _; + name = _; + impl_trait = _; + generics = _; + preds = _; + parent_trait_refs; + consts; + types; + required_methods; + provided_methods; + } = + trait_impl + in + parent_trait_refs = [] && consts = [] && types = [] && required_methods = [] + && provided_methods = [] diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml index fc4744bc..10b68da3 100644 --- a/compiler/ReorderDecls.ml +++ b/compiler/ReorderDecls.ml @@ -38,14 +38,16 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t = method! visit_qualif _ id = match id.id with - | FunOrOp (Unop _ | Binop _) | Global _ | AdtCons _ | Proj _ -> () + | FunOrOp (Unop _ | Binop _) + | Global _ | AdtCons _ | Proj _ | TraitConst _ -> + () | FunOrOp (Fun fid) -> ( match fid with | Pure _ -> () | FromLlbc (fid, lp_id, rg_id) -> ( match fid with - | Assumed _ -> () - | Regular fid -> + | FunId (Assumed _) -> () + | TraitMethod (_, _, fid) | FunId (Regular fid) -> let id = { def_id = fid; lp_id; rg_id } in ids := FunIdSet.add id !ids)) end diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml index 38850243..23f618e2 100644 --- a/compiler/Substitute.ml +++ b/compiler/Substitute.ml @@ -9,51 +9,70 @@ module E = Expressions module A = LlbcAst module C = Contexts -(** Substitute types variables and regions in a type. *) -let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : 'r1 T.ty) : - 'r2 T.ty = - let open T in - let visitor = - object - inherit [_] map_ty - method visit_'r _ r = rsubst r - method! visit_TypeVar _ id = tsubst id +type ('r1, 'r2) subst = { + r_subst : 'r1 -> 'r2; + ty_subst : T.TypeVarId.id -> 'r2 T.ty; + cg_subst : T.ConstGenericVarId.id -> T.const_generic; + (** Substitution from *local* trait clause to trait instance *) + tr_subst : T.TraitClauseId.id -> 'r2 T.trait_instance_id; + (** Substitution for the [Self] trait instance *) + tr_self : 'r2 T.trait_instance_id; +} + +let ty_substitute_visitor (subst : ('r1, 'r2) subst) = + object + inherit [_] T.map_ty + method visit_'r _ r = subst.r_subst r + method! visit_TypeVar _ id = subst.ty_subst id - method! visit_type_var_id _ _ = - (* We should never get here because we reimplemented [visit_TypeVar] *) - raise (Failure "Unexpected") + method! visit_type_var_id _ _ = + (* We should never get here because we reimplemented [visit_TypeVar] *) + raise (Failure "Unexpected") - method! visit_ConstGenericVar _ id = cgsubst id + method! visit_ConstGenericVar _ id = subst.cg_subst id - method! visit_const_generic_var_id _ _ = - (* We should never get here because we reimplemented [visit_Var] *) - raise (Failure "Unexpected") - end - in + method! visit_const_generic_var_id _ _ = + (* We should never get here because we reimplemented [visit_Var] *) + raise (Failure "Unexpected") - visitor#visit_ty () ty + method! visit_Clause _ id = subst.tr_subst id + method! visit_Self _ = subst.tr_self + end -let rty_substitute (rsubst : T.RegionId.id -> T.RegionId.id) - (tsubst : T.TypeVarId.id -> T.rty) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.rty) : T.rty = - let rsubst r = - match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) - in - ty_substitute rsubst tsubst cgsubst ty +(** Substitute types variables and regions in a type. -let ety_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.ety) : T.ety = - let rsubst r = r in - ty_substitute rsubst tsubst cgsubst ty + **IMPORTANT**: this doesn't normalize the types. + *) +let ty_substitute (subst : ('r1, 'r2) subst) (ty : 'r1 T.ty) : 'r2 T.ty = + let visitor = ty_substitute_visitor subst in + visitor#visit_ty () ty + +(** **IMPORTANT**: this doesn't normalize the types. *) +let trait_ref_substitute (subst : ('r1, 'r2) subst) (tr : 'r1 T.trait_ref) : + 'r2 T.trait_ref = + let visitor = ty_substitute_visitor subst in + visitor#visit_trait_ref () tr + +(** **IMPORTANT**: this doesn't normalize the types. *) +let generic_args_substitute (subst : ('r1, 'r2) subst) (g : 'r1 T.generic_args) + : 'r2 T.generic_args = + let visitor = ty_substitute_visitor subst in + visitor#visit_generic_args () g + +let erase_regions_subst : ('r, T.erased_region) subst = + { + r_subst = (fun _ -> T.Erased); + ty_subst = (fun vid -> T.TypeVar vid); + cg_subst = (fun id -> T.ConstGenericVar id); + tr_subst = (fun id -> T.Clause id); + tr_self = T.Self; + } (** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *) -let erase_regions (ty : T.rty) : T.ety = - ty_substitute - (fun _ -> T.Erased) - (fun vid -> T.TypeVar vid) - (fun id -> T.ConstGenericVar id) - ty +let erase_regions (ty : 'r T.ty) : T.ety = ty_substitute erase_regions_subst ty + +let trait_ref_erase_regions (tr : 'r T.trait_ref) : T.etrait_ref = + trait_ref_substitute erase_regions_subst tr (** Generate fresh regions for region variables. @@ -78,18 +97,20 @@ let fresh_regions_with_substs (region_vars : T.region_var list) : (* Generate the substitution from region var id to region *) let rid_subst id = T.RegionVarId.Map.find id rid_map in (* Generate the substitution from region to region *) - let rsubst r = + let r_subst r = match r with T.Static -> T.Static | T.Var id -> T.Var (rid_subst id) in (* Return *) - (fresh_region_ids, rid_subst, rsubst) + (fresh_region_ids, rid_subst, r_subst) -(** Erase the regions in a type and substitute the type variables *) -let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) - (ty : 'r T.region T.ty) : T.ety = - let rsubst (_ : 'r T.region) : T.erased_region = T.Erased in - ty_substitute rsubst tsubst cgsubst ty +(** Erase the regions in a type and perform a substitution *) +let erase_regions_substitute_types (ty_subst : T.TypeVarId.id -> T.ety) + (cg_subst : T.ConstGenericVarId.id -> T.const_generic) + (tr_subst : T.TraitClauseId.id -> T.etrait_instance_id) + (tr_self : T.etrait_instance_id) (ty : 'r T.ty) : T.ety = + let r_subst (_ : 'r) : T.erased_region = T.Erased in + let subst = { r_subst; ty_subst; cg_subst; tr_subst; tr_self } in + ty_substitute subst ty (** Create a region substitution from a list of region variable ids and a list of regions (with which to substitute the region variable ids *) @@ -146,16 +167,81 @@ let make_const_generic_subst_from_vars (vars : T.const_generic_var list) (List.map (fun (x : T.const_generic_var) -> x.T.index) vars) cgs -(** Instantiate the type variables in an ADT definition, and return, for - every variant, the list of the types of its fields *) -let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) : (T.VariantId.id option * T.rty list) list = - let r_subst = make_region_subst_from_vars def.T.region_params regions in - let ty_subst = make_type_subst_from_vars def.T.type_params types in +(** Create a trait substitution from a list of trait clause ids and a list of + trait refs *) +let make_trait_subst (clause_ids : T.TraitClauseId.id list) + (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id = + let ls = List.combine clause_ids trs in + let mp = + List.fold_left + (fun mp (k, v) -> T.TraitClauseId.Map.add k (T.TraitRef v) mp) + T.TraitClauseId.Map.empty ls + in + fun id -> T.TraitClauseId.Map.find id mp + +let make_trait_subst_from_clauses (clauses : T.trait_clause list) + (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id = + make_trait_subst + (List.map (fun (x : T.trait_clause) -> x.T.clause_id) clauses) + trs + +let make_subst_from_generics (params : T.generic_params) + (args : 'r T.region T.generic_args) + (tr_self : 'r T.region T.trait_instance_id) : + (T.region_var_id T.region, 'r T.region) subst = + let r_subst = make_region_subst_from_vars params.T.regions args.T.regions in + let ty_subst = make_type_subst_from_vars params.T.types args.T.types in let cg_subst = - make_const_generic_subst_from_vars def.T.const_generic_params cgs + make_const_generic_subst_from_vars params.T.const_generics + args.T.const_generics + in + let tr_subst = + make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs + in + { r_subst; ty_subst; cg_subst; tr_subst; tr_self } + +let make_subst_from_generics_no_regions : + 'r. + T.generic_params -> + 'r T.generic_args -> + 'r T.trait_instance_id -> + (T.region_var_id T.region, 'r) subst = + fun params args tr_self -> + let r_subst _ = raise (Failure "Unexpected region") in + let ty_subst = make_type_subst_from_vars params.T.types args.T.types in + let cg_subst = + make_const_generic_subst_from_vars params.T.const_generics + args.T.const_generics + in + let tr_subst = + make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs + in + { r_subst; ty_subst; cg_subst; tr_subst; tr_self } + +let make_esubst_from_generics (params : T.generic_params) + (generics : T.egeneric_args) (tr_self : T.etrait_instance_id) = + let r_subst _ = T.Erased in + let ty_subst = make_type_subst_from_vars params.types generics.T.types in + let cg_subst = + make_const_generic_subst_from_vars params.const_generics + generics.T.const_generics + in + let tr_subst = + make_trait_subst_from_clauses params.trait_clauses generics.T.trait_refs in + { r_subst; ty_subst; cg_subst; tr_subst; tr_self } + +(** Instantiate the type variables in an ADT definition, and return, for + every variant, the list of the types of its fields. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. +*) +let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) + (generics : T.rgeneric_args) : (T.VariantId.id option * T.rty list) list = + (* There shouldn't be any reference to Self *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics def.T.generics generics tr_self in let (variants_fields : (T.VariantId.id option * T.field list) list) = match def.T.kind with | T.Enum variants -> @@ -171,191 +257,220 @@ let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) in List.map (fun (id, fields) -> - ( id, - List.map - (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty) - fields )) + (id, List.map (fun f -> ty_substitute subst f.T.field_ty) fields)) variants_fields (** Instantiate the type variables in an ADT definition, and return the list - of types of the fields for the chosen variant *) + of types of the fields for the chosen variant. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. +*) let type_decl_get_instantiated_field_rtypes (def : T.type_decl) - (opt_variant_id : T.VariantId.id option) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) : T.rty list = - let r_subst = make_region_subst_from_vars def.T.region_params regions in - let ty_subst = make_type_subst_from_vars def.T.type_params types in - let cg_subst = - make_const_generic_subst_from_vars def.T.const_generic_params cgs - in + (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) : + T.rty list = + (* For now, check that there are no clauses - otherwise we might need + to normalize the types *) + assert (def.generics.trait_clauses = []); + (* There shouldn't be any reference to Self *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics def.T.generics generics tr_self in let fields = TU.type_decl_get_fields def opt_variant_id in - List.map - (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty) - fields + List.map (fun f -> ty_substitute subst f.T.field_ty) fields (** Return the types of the properly instantiated ADT's variant, provided a - context *) + context. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. +*) let ctx_adt_get_instantiated_field_rtypes (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (regions : T.RegionId.id T.region list) (types : T.rty list) - (cgs : T.const_generic list) : T.rty list = + (generics : T.rgeneric_args) : T.rty list = let def = C.ctx_lookup_type_decl ctx def_id in - type_decl_get_instantiated_field_rtypes def opt_variant_id regions types cgs + type_decl_get_instantiated_field_rtypes def opt_variant_id generics (** Return the types of the properly instantiated ADT value (note that - here, ADT is understood in its broad meaning: ADT, assumed value or tuple) *) + here, ADT is understood in its broad meaning: ADT, assumed value or tuple). + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. + *) let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx) - (adt : V.adt_value) (id : T.type_id) - (region_params : T.RegionId.id T.region list) (type_params : T.rty list) - (cg_params : T.const_generic list) : T.rty list = + (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) : + T.rty list = match id with | T.AdtId id -> (* Retrieve the types of the fields *) - ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id - region_params type_params cg_params + ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id generics | T.Tuple -> - assert (List.length region_params = 0); - type_params + assert (generics.regions = []); + generics.types | T.Assumed aty -> ( match aty with - | T.Box | T.Vec -> - assert (List.length region_params = 0); - assert (List.length type_params = 1); - assert (List.length cg_params = 0); - type_params - | T.Option -> - assert (List.length region_params = 0); - assert (List.length type_params = 1); - assert (List.length cg_params = 0); - if adt.V.variant_id = Some T.option_some_id then type_params - else if adt.V.variant_id = Some T.option_none_id then [] - else raise (Failure "Unreachable") - | T.Range -> - assert (List.length region_params = 0); - assert (List.length type_params = 1); - assert (List.length cg_params = 0); - type_params + | T.Box -> + assert (generics.regions = []); + assert (List.length generics.types = 1); + assert (generics.const_generics = []); + generics.types | T.Array | T.Slice | T.Str -> (* Those types don't have fields *) raise (Failure "Unreachable")) (** Instantiate the type variables in an ADT definition, and return the list - of types of the fields for the chosen variant *) + of types of the fields for the chosen variant. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. +*) let type_decl_get_instantiated_field_etypes (def : T.type_decl) - (opt_variant_id : T.VariantId.id option) (types : T.ety list) - (cgs : T.const_generic list) : T.ety list = - let ty_subst = make_type_subst_from_vars def.T.type_params types in - let cg_subst = - make_const_generic_subst_from_vars def.T.const_generic_params cgs + (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) : + T.ety list = + (* For now, check that there are no clauses - otherwise we might need + to normalize the types *) + assert (def.generics.trait_clauses = []); + (* There shouldn't be any reference to Self *) + let tr_self : T.erased_region T.trait_instance_id = + T.UnknownTrait __FUNCTION__ + in + let { r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } = + make_esubst_from_generics def.T.generics generics tr_self in let fields = TU.type_decl_get_fields def opt_variant_id in List.map - (fun f -> erase_regions_substitute_types ty_subst cg_subst f.T.field_ty) + (fun (f : T.field) -> + erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self + f.T.field_ty) fields (** Return the types of the properly instantiated ADT's variant, provided a - context *) + context. + + **IMPORTANT**: this function doesn't normalize the types, you may want to + use the [AssociatedTypes] equivalent instead. + *) let ctx_adt_get_instantiated_field_etypes (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (types : T.ety list) (cgs : T.const_generic list) : T.ety list = + (generics : T.egeneric_args) : T.ety list = let def = C.ctx_lookup_type_decl ctx def_id in - type_decl_get_instantiated_field_etypes def opt_variant_id types cgs + type_decl_get_instantiated_field_etypes def opt_variant_id generics -let statement_substitute_visitor (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) = +let statement_substitute_visitor + (subst : (T.erased_region, T.erased_region) subst) = + (* Keep in synch with [ty_substitute_visitor] *) object inherit [_] A.map_statement - method! visit_ety _ ty = ety_substitute tsubst cgsubst ty - method! visit_ConstGenericVar _ id = cgsubst id + method! visit_'r _ r = subst.r_subst r + method! visit_TypeVar _ id = subst.ty_subst id + + method! visit_type_var_id _ _ = + (* We should never get here because we reimplemented [visit_TypeVar] *) + raise (Failure "Unexpected") + + method! visit_ConstGenericVar _ id = subst.cg_subst id method! visit_const_generic_var_id _ _ = (* We should never get here because we reimplemented [visit_Var] *) raise (Failure "Unexpected") + + method! visit_Clause _ id = subst.tr_subst id + method! visit_Self _ = subst.tr_self end (** Apply a type substitution to a place *) -let place_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (p : E.place) : - E.place = +let place_substitute (subst : (T.erased_region, T.erased_region) subst) + (p : E.place) : E.place = (* There is in fact nothing to do *) - (statement_substitute_visitor tsubst cgsubst)#visit_place () p + (statement_substitute_visitor subst)#visit_place () p (** Apply a type substitution to an operand *) -let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (op : E.operand) : - E.operand = - (statement_substitute_visitor tsubst cgsubst)#visit_operand () op +let operand_substitute (subst : (T.erased_region, T.erased_region) subst) + (op : E.operand) : E.operand = + (statement_substitute_visitor subst)#visit_operand () op (** Apply a type substitution to an rvalue *) -let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (rv : E.rvalue) : - E.rvalue = - (statement_substitute_visitor tsubst cgsubst)#visit_rvalue () rv +let rvalue_substitute (subst : (T.erased_region, T.erased_region) subst) + (rv : E.rvalue) : E.rvalue = + (statement_substitute_visitor subst)#visit_rvalue () rv (** Apply a type substitution to an assertion *) -let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (a : A.assertion) : - A.assertion = - (statement_substitute_visitor tsubst cgsubst)#visit_assertion () a +let assertion_substitute (subst : (T.erased_region, T.erased_region) subst) + (a : A.assertion) : A.assertion = + (statement_substitute_visitor subst)#visit_assertion () a (** Apply a type substitution to a call *) -let call_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (call : A.call) : - A.call = - (statement_substitute_visitor tsubst cgsubst)#visit_call () call +let call_substitute (subst : (T.erased_region, T.erased_region) subst) + (call : A.call) : A.call = + (statement_substitute_visitor subst)#visit_call () call (** Apply a type substitution to a statement *) -let statement_substitute (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (st : A.statement) : - A.statement = - (statement_substitute_visitor tsubst cgsubst)#visit_statement () st +let statement_substitute (subst : (T.erased_region, T.erased_region) subst) + (st : A.statement) : A.statement = + (statement_substitute_visitor subst)#visit_statement () st (** Apply a type substitution to a function body. Return the local variables and the body. *) -let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (body : A.fun_body) : +let fun_body_substitute_in_body + (subst : (T.erased_region, T.erased_region) subst) (body : A.fun_body) : A.var list * A.statement = - let rsubst r = r in let locals = List.map - (fun (v : A.var) -> - { v with A.var_ty = ty_substitute rsubst tsubst cgsubst v.A.var_ty }) + (fun (v : A.var) -> { v with A.var_ty = ty_substitute subst v.A.var_ty }) body.A.locals in - let body = statement_substitute tsubst cgsubst body.body in + let body = statement_substitute subst body.body in (locals, body) -(** Substitute a function signature *) +let trait_type_constraint_substitute (subst : ('r1, 'r2) subst) + (ttc : 'r1 T.trait_type_constraint) : 'r2 T.trait_type_constraint = + let { T.trait_ref; generics; type_name; ty } = ttc in + let visitor = ty_substitute_visitor subst in + let trait_ref = visitor#visit_trait_ref () trait_ref in + let generics = visitor#visit_generic_args () generics in + let ty = visitor#visit_ty () ty in + { T.trait_ref; generics; type_name; ty } + +(** Substitute a function signature. + + **IMPORTANT:** this function doesn't normalize the types. + *) let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) - (rsubst : T.RegionVarId.id -> T.RegionId.id) - (tsubst : T.TypeVarId.id -> T.rty) - (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (sg : A.fun_sig) : - A.inst_fun_sig = - let rsubst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region = - match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) + (r_subst : T.RegionVarId.id -> T.RegionId.id) + (ty_subst : T.TypeVarId.id -> T.rty) + (cg_subst : T.ConstGenericVarId.id -> T.const_generic) + (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id) + (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig = + let r_subst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region = + match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid) in - let inputs = List.map (ty_substitute rsubst' tsubst cgsubst) sg.A.inputs in - let output = ty_substitute rsubst' tsubst cgsubst sg.A.output in + let subst = { r_subst = r_subst'; ty_subst; cg_subst; tr_subst; tr_self } in + let inputs = List.map (ty_substitute subst) sg.A.inputs in + let output = ty_substitute subst sg.A.output in let subst_region_group (rg : T.region_var_group) : A.abs_region_group = let id = asubst rg.id in - let regions = List.map rsubst rg.regions in + let regions = List.map r_subst rg.regions in let parents = List.map asubst rg.parents in { id; regions; parents } in let regions_hierarchy = List.map subst_region_group sg.A.regions_hierarchy in - { A.regions_hierarchy; inputs; output } + let trait_type_constraints = + List.map + (trait_type_constraint_substitute subst) + sg.preds.trait_type_constraints + in + { A.inputs; output; regions_hierarchy; trait_type_constraints } -(** Substitute type variable identifiers in a type *) -let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty) +(** Substitute variable identifiers in a type *) +let ty_substitute_ids (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty) : 'r T.ty = let open T in let visitor = object inherit [_] map_ty method visit_'r _ r = r - method! visit_type_var_id _ id = tsubst id - method! visit_const_generic_var_id _ id = cgsubst id + method! visit_type_var_id _ id = ty_subst id + method! visit_const_generic_var_id _ id = cg_subst id end in @@ -371,10 +486,10 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) [visit_'r] if we define a class which visits objects of types [ety] and [rty] while inheriting a class which visit [ty]... *) -let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) +let subst_ids_visitor (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) = @@ -383,10 +498,10 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) inherit [_] T.map_ty method visit_'r _ r = - match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) + match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid) - method! visit_type_var_id _ id = tsubst id - method! visit_const_generic_var_id _ id = cgsubst id + method! visit_type_var_id _ id = ty_subst id + method! visit_const_generic_var_id _ id = cg_subst id end in @@ -395,7 +510,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) inherit [_] C.map_env method! visit_borrow_id _ bid = bsubst bid method! visit_loan_id _ bid = bsubst bid - method! visit_ety _ ty = ty_substitute_ids tsubst cgsubst ty + method! visit_ety _ ty = ty_substitute_ids ty_subst cg_subst ty method! visit_rty env ty = subst_rty#visit_ty env ty method! visit_symbolic_value_id _ id = ssubst id @@ -405,7 +520,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) (** We *do* visit meta-values *) method! visit_mvalue env v = self#visit_typed_value env v - method! visit_region_id _ id = rsubst id + method! visit_region_id _ id = r_subst id method! visit_region_var_id _ id = rvsubst id method! visit_abstraction_id _ id = asubst id end @@ -425,20 +540,20 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) method visit_env (env : C.env) : C.env = visitor#visit_env () env end -let typed_value_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) +let typed_value_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_value) : V.typed_value = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst) #visit_typed_value v -let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) +let typed_value_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id) (v : V.typed_value) : V.typed_value = - typed_value_subst_ids rsubst + typed_value_subst_ids r_subst (fun x -> x) (fun x -> x) (fun x -> x) @@ -446,41 +561,41 @@ let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (fun x -> x) v -let typed_avalue_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) +let typed_avalue_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_avalue) : V.typed_avalue = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst) #visit_typed_avalue v -let abs_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) +let abs_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : V.abs) : V.abs = - (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst) #visit_abs x -let env_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) +let env_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) - (tsubst : T.TypeVarId.id -> T.TypeVarId.id) - (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) + (ty_subst : T.TypeVarId.id -> T.TypeVarId.id) + (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : C.env) : C.env = - (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst) #visit_env x -let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) +let typed_avalue_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id) (x : V.typed_avalue) : V.typed_avalue = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst + (subst_ids_visitor r_subst (fun x -> x) (fun x -> x) (fun x -> x) @@ -490,9 +605,9 @@ let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) #visit_typed_avalue x -let env_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : C.env) : C.env - = - (subst_ids_visitor rsubst +let env_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id) (x : C.env) : + C.env = + (subst_ids_visitor r_subst (fun x -> x) (fun x -> x) (fun x -> x) diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 7dc94dcd..4df8fec7 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -29,7 +29,7 @@ type mplace = { [@@deriving show] type call_id = - | Fun of A.fun_id * V.FunCallId.id + | Fun of A.fun_id_or_trait_method_ref * V.FunCallId.id (** A "regular" function (i.e., a function which is not a primitive operation) *) | Unop of E.unop | Binop of E.binop @@ -43,10 +43,7 @@ type call = { borrows (we need to perform lookups). *) abstractions : V.AbstractionId.id list; - (* TODO: rename to "...args" *) - type_params : T.ety list; - (* TODO: rename to "...args" *) - const_generic_params : T.const_generic list; + generics : T.egeneric_args; args : V.typed_value list; args_places : mplace option list; (** Meta information *) dest : V.symbolic_value; @@ -79,6 +76,9 @@ class ['self] iter_expression_base = method visit_loop_id : 'env -> V.loop_id -> unit = fun _ _ -> () method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> () + method visit_const_generic_var_id : 'env -> T.const_generic_var_id -> unit = + fun _ _ -> () + method visit_symbolic_value_id : 'env -> V.symbolic_value_id -> unit = fun _ _ -> () @@ -120,6 +120,9 @@ class ['self] iter_expression_base = method visit_symbolic_expansion : 'env -> V.symbolic_expansion -> unit = fun _ _ -> () + + method visit_etrait_ref : 'env -> T.etrait_ref -> unit = fun _ _ -> () + method visit_egeneric_args : 'env -> T.egeneric_args -> unit = fun _ _ -> () end (** **Rem.:** here, {!expression} is not at all equivalent to the expressions @@ -171,14 +174,15 @@ type expression = * expression (** We introduce a new symbolic value, equal to some other value. - This is used for instance when reorganizing the environment to compute - fixed points: we duplicate some shared symbolic values to destructure - the shared values, in order to make the environment a bit more general - (while losing precision of course). + This is used for instance when reorganizing the environment to compute + fixed points: we duplicate some shared symbolic values to destructure + the shared values, in order to make the environment a bit more general + (while losing precision of course). We also use it to introduce symbolic + values when evaluating constant generics, or trait constants. - The context is the evaluation context from before introducing the new - value. It has the same purpose as for the {!Return} case. - *) + The context is the evaluation context from before introducing the new + value. It has the same purpose as for the {!Return} case. + *) | ForwardEnd of Contexts.eval_ctx * V.typed_value symbolic_value_id_map option @@ -253,6 +257,11 @@ and value_aggregate = | SingleValue of V.typed_value (** Regular case *) | Array of V.typed_value list (** This is used when introducing array aggregates *) + | ConstGenericValue of T.const_generic_var_id + (** This is used when evaluating a const generic value: in the interpreter, + we introduce a fresh symbolic value. *) + | TraitConstValue of T.etrait_ref * T.egeneric_args * string + (** A trait constant value *) [@@deriving show, visitors diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 3512270a..2ce8c706 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -4,6 +4,7 @@ open Pure open PureUtils module Id = Identifiers module C = Contexts +module A = LlbcAst module S = SymbolicAst module TA = TypesAnalysis module L = Logging @@ -52,6 +53,9 @@ type fun_context = { type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t } [@@deriving show] +type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show] +type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show] + (** Whenever we translate a function call or an ended abstraction, we store the related information (this is useful when translating ended children abstractions). @@ -106,8 +110,7 @@ type loop_info = { loop_id : LoopId.id; input_vars : var list; input_svl : V.symbolic_value list; - type_args : ty list; - const_generic_args : const_generic list; + generics : generic_args; forward_inputs : texpression list option; (** The forward inputs are initialized at [None] *) forward_output_no_state_no_result : var option; @@ -120,6 +123,8 @@ type bs_ctx = { type_context : type_context; fun_context : fun_context; global_context : global_context; + trait_decls_ctx : trait_decls_context; + trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; bid : T.RegionGroupId.id option; (** TODO: rename *) sg : fun_sig; @@ -201,34 +206,11 @@ type bs_ctx = { } [@@deriving show] -let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit = - let env = VarId.Map.empty in - let ctx = - { - PureTypeCheck.type_decls = ctx.type_context.type_decls; - global_decls = ctx.global_context.llbc_global_decls; - env; - } - in - let _ = PureTypeCheck.check_typed_pattern ctx v in - () - -let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit = - let env = VarId.Map.empty in - let ctx = - { - PureTypeCheck.type_decls = ctx.type_context.type_decls; - global_decls = ctx.global_context.llbc_global_decls; - env; - } - in - PureTypeCheck.check_texpression ctx e - (* TODO: move *) let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.Ast.ast_formatter = Print.Ast.decls_and_fun_decl_to_ast_formatter ctx.type_context.llbc_type_decls ctx.fun_context.llbc_fun_decls ctx.global_context.llbc_global_decls - ctx.fun_decl + ctx.trait_decls_ctx ctx.trait_impls_ctx ctx.fun_decl let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter = let rvar_to_string = Print.Types.region_var_id_to_string in @@ -246,16 +228,25 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter = adt_variant_to_string = ast_fmt.adt_variant_to_string; var_id_to_string; adt_field_names = ast_fmt.adt_field_names; + trait_decl_id_to_string = ast_fmt.trait_decl_id_to_string; + trait_impl_id_to_string = ast_fmt.trait_impl_id_to_string; + trait_clause_id_to_string = ast_fmt.trait_clause_id_to_string; } let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter = - let type_params = ctx.fun_decl.signature.type_params in - let cg_params = ctx.fun_decl.signature.const_generic_params in + let generics = ctx.fun_decl.signature.generics in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + PrintPure.mk_ast_formatter type_decls fun_decls global_decls + ctx.trait_decls_ctx ctx.trait_impls_ctx generics.types + generics.const_generics + +let ctx_egeneric_args_to_string (ctx : bs_ctx) (args : T.egeneric_args) : string + = + let fmt = bs_ctx_to_ctx_formatter ctx in + let fmt = Print.PC.ctx_to_etype_formatter fmt in + Print.PT.egeneric_args_to_string fmt args let symbolic_value_to_string (ctx : bs_ctx) (sv : V.symbolic_value) : string = let fmt = bs_ctx_to_ctx_formatter ctx in @@ -277,12 +268,11 @@ let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string = Print.PT.rty_to_string fmt ty let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string = - let type_params = def.type_params in - let cg_params = def.const_generic_params in let type_decls = ctx.type_context.llbc_type_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = - PrintPure.mk_type_formatter type_decls global_decls type_params cg_params + PrintPure.mk_type_formatter type_decls global_decls ctx.trait_decls_ctx + ctx.trait_impls_ctx def.generics.types def.generics.const_generics in PrintPure.type_decl_to_string fmt def @@ -291,26 +281,27 @@ let texpression_to_string (ctx : bs_ctx) (e : texpression) : string = PrintPure.texpression_to_string fmt false "" " " e let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string = - let type_params = sg.type_params in - let cg_params = sg.const_generic_params in + let type_params = sg.generics.types in + let cg_params = sg.generics.const_generics in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + PrintPure.mk_ast_formatter type_decls fun_decls global_decls + ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string = - let type_params = def.signature.type_params in - let cg_params = def.signature.const_generic_params in + let generics = def.signature.generics in + let type_params = generics.types in + let cg_params = generics.const_generics in let type_decls = ctx.type_context.llbc_type_decls in let fun_decls = ctx.fun_context.llbc_fun_decls in let global_decls = ctx.global_context.llbc_global_decls in let fmt = - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + PrintPure.mk_ast_formatter type_decls fun_decls global_decls + ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params in PrintPure.fun_decl_to_string fmt def @@ -328,17 +319,18 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = Print.Values.abs_to_string fmt verbose indent indent_incr abs let get_instantiated_fun_sig (fun_id : A.fun_id) - (back_id : T.RegionGroupId.id option) (tys : ty list) - (cgs : const_generic list) (ctx : bs_ctx) : inst_fun_sig = + (back_id : T.RegionGroupId.id option) (generics : generic_args) + (ctx : bs_ctx) : inst_fun_sig = (* Lookup the non-instantiated function signature *) let sg = (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg in (* Create the substitution *) - let tsubst = make_type_subst sg.type_params tys in - let cgsubst = make_const_generic_subst sg.const_generic_params cgs in + (* There shouldn't be any reference to Self *) + let tr_self = UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics sg.generics generics tr_self in (* Apply *) - fun_sig_substitute tsubst cgsubst sg + fun_sig_substitute subst sg let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = @@ -351,77 +343,128 @@ let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : (* TODO: move *) let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id) (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig = - let id = (A.Regular def_id, back_id) in + let id = (E.Regular def_id, back_id) in (RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg -let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) - (args : texpression list) (ctx : bs_ctx) : bs_ctx = - let calls = ctx.calls in - assert (not (V.FunCallId.Map.mem call_id calls)); - let info = - { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty } - in - let calls = V.FunCallId.Map.add call_id info calls in - { ctx with calls } - -(** [back_args]: the *additional* list of inputs received by the backward function *) -let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) - (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) - : bs_ctx * fun_or_op_id = - (* Insert the abstraction in the call informations *) - let info = V.FunCallId.Map.find call_id ctx.calls in - assert (not (T.RegionGroupId.Map.mem back_id info.backwards)); - let backwards = - T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards - in - let info = { info with backwards } in - let calls = V.FunCallId.Map.add call_id info ctx.calls in - (* Insert the abstraction in the abstractions map *) - let abstractions = ctx.abstractions in - assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions)); - let abstractions = - V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions - in - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") - in - (* Update the context and return *) - ({ ctx with calls; abstractions }, fun_id) +(* Some generic translation functions (we need to translate different "flavours" + of types: sty, forward types, backward types, etc.) *) +let rec translate_generic_args (translate_ty : 'r T.ty -> ty) + (generics : 'r T.generic_args) : generic_args = + (* We ignore the regions: if they didn't cause trouble for the symbolic execution, + then everything's fine *) + let types = List.map translate_ty generics.types in + let const_generics = generics.const_generics in + let trait_refs = + List.map (translate_trait_ref translate_ty) generics.trait_refs + in + { types; const_generics; trait_refs } + +and translate_trait_ref (translate_ty : 'r T.ty -> ty) (tr : 'r T.trait_ref) : + trait_ref = + let trait_id = translate_trait_instance_id translate_ty tr.trait_id in + let generics = translate_generic_args translate_ty tr.generics in + let trait_decl_ref = + translate_trait_decl_ref translate_ty tr.trait_decl_ref + in + { trait_id; generics; trait_decl_ref } + +and translate_trait_decl_ref (translate_ty : 'r T.ty -> ty) + (tr : 'r T.trait_decl_ref) : trait_decl_ref = + let decl_generics = translate_generic_args translate_ty tr.decl_generics in + { trait_decl_id = tr.trait_decl_id; decl_generics } + +and translate_trait_instance_id (translate_ty : 'r T.ty -> ty) + (id : 'r T.trait_instance_id) : trait_instance_id = + let translate_trait_instance_id = translate_trait_instance_id translate_ty in + match id with + | T.Self -> Self + | TraitImpl id -> TraitImpl id + | BuiltinOrAuto _ -> + (* We should have eliminated those in the prepasses *) + raise (Failure "Unreachable") + | Clause id -> Clause id + | ParentClause (inst_id, decl_id, clause_id) -> + let inst_id = translate_trait_instance_id inst_id in + ParentClause (inst_id, decl_id, clause_id) + | ItemClause (inst_id, decl_id, item_name, clause_id) -> + let inst_id = translate_trait_instance_id inst_id in + ItemClause (inst_id, decl_id, item_name, clause_id) + | TraitRef tr -> TraitRef (translate_trait_ref translate_ty tr) + | FnPointer _ -> raise (Failure "TODO") + | UnknownTrait s -> raise (Failure ("Unknown trait found: " ^ s)) let rec translate_sty (ty : T.sty) : ty = let translate = translate_sty in match ty with - | T.Adt (type_id, regions, tys, cgs) -> ( - (* Can't translate types with regions for now *) - assert (regions = []); - let tys = List.map translate tys in + | T.Adt (type_id, generics) -> ( + let generics = translate_sgeneric_args generics in match type_id with - | T.AdtId adt_id -> Adt (AdtId adt_id, tys, cgs) - | T.Tuple -> mk_simpl_tuple_ty tys + | T.AdtId adt_id -> Adt (AdtId adt_id, generics) + | T.Tuple -> + assert (generics.const_generics = []); + mk_simpl_tuple_ty generics.types | T.Assumed aty -> ( match aty with - | T.Vec -> Adt (Assumed Vec, tys, cgs) - | T.Option -> Adt (Assumed Option, tys, cgs) | T.Box -> ( (* Eliminate the boxes *) - match tys with + match generics.types with | [ ty ] -> ty | _ -> raise (Failure "Box/vec/option type with incorrect number of arguments") ) - | T.Array -> Adt (Assumed Array, tys, cgs) - | T.Slice -> Adt (Assumed Slice, tys, cgs) - | T.Str -> Adt (Assumed Str, tys, cgs) - | T.Range -> Adt (Assumed Range, tys, cgs))) + | T.Array -> Adt (Assumed Array, generics) + | T.Slice -> Adt (Assumed Slice, generics) + | T.Str -> Adt (Assumed Str, generics))) | TypeVar vid -> TypeVar vid | Literal ty -> Literal ty | Never -> raise (Failure "Unreachable") | Ref (_, rty, _) -> translate rty + | RawPtr (ty, rkind) -> + let mut = match rkind with Mut -> Mut | Shared -> Const in + let ty = translate ty in + let generics = { types = [ ty ]; const_generics = []; trait_refs = [] } in + Adt (Assumed (RawPtr mut), generics) + | TraitType (trait_ref, generics, type_name) -> + let trait_ref = translate_strait_ref trait_ref in + let generics = translate_sgeneric_args generics in + TraitType (trait_ref, generics, type_name) + | Arrow _ -> raise (Failure "TODO") + +and translate_sgeneric_args (generics : T.sgeneric_args) : generic_args = + translate_generic_args translate_sty generics + +and translate_strait_ref (tr : T.strait_ref) : trait_ref = + translate_trait_ref translate_sty tr + +and translate_strait_instance_id (id : T.strait_instance_id) : trait_instance_id + = + translate_trait_instance_id translate_sty id + +let translate_trait_clause (clause : T.trait_clause) : trait_clause = + let { T.clause_id; meta = _; trait_id; generics } = clause in + let generics = translate_sgeneric_args generics in + { clause_id; trait_id; generics } + +let translate_strait_type_constraint (ttc : T.strait_type_constraint) : + trait_type_constraint = + let { T.trait_ref; generics; type_name; ty } = ttc in + let trait_ref = translate_strait_ref trait_ref in + let generics = translate_sgeneric_args generics in + let ty = translate_sty ty in + { trait_ref; generics; type_name; ty } + +let translate_predicates (preds : T.predicates) : predicates = + let trait_type_constraints = + List.map translate_strait_type_constraint preds.trait_type_constraints + in + { trait_type_constraints } + +let translate_generic_params (generics : T.generic_params) : generic_params = + let { T.regions = _; types; const_generics; trait_clauses } = generics in + let trait_clauses = List.map translate_trait_clause trait_clauses in + { types; const_generics; trait_clauses } let translate_field (f : T.field) : field = let field_name = f.field_name in @@ -452,15 +495,16 @@ let translate_type_decl_kind (kind : T.type_decl_kind) : type_decl_kind = point of moving this definition for now. *) let translate_type_decl (def : T.type_decl) : type_decl = - (* Translate *) let def_id = def.T.def_id in let name = def.name in + let { T.regions; types; const_generics; trait_clauses } = def.generics in (* Can't translate types with regions for now *) - assert (def.region_params = []); - let type_params = def.type_params in - let const_generic_params = def.const_generic_params in + assert (regions = []); + let trait_clauses = List.map translate_trait_clause trait_clauses in + let generics = { types; const_generics; trait_clauses } in let kind = translate_type_decl_kind def.T.kind in - { def_id; name; type_params; const_generic_params; kind } + let preds = translate_predicates def.preds in + { def_id; name; generics; kind; preds } let translate_type_id (id : T.type_id) : type_id = match id with @@ -468,12 +512,9 @@ let translate_type_id (id : T.type_id) : type_id = | T.Assumed aty -> let aty = match aty with - | T.Vec -> Vec - | T.Option -> Option | T.Array -> Array | T.Slice -> Slice | T.Str -> Str - | T.Range -> Range | T.Box -> (* Boxes have to be eliminated: this type id shouldn't be translated *) @@ -488,28 +529,26 @@ let translate_type_id (id : T.type_id) : type_id = let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = let translate = translate_fwd_ty type_infos in match ty with - | T.Adt (type_id, regions, tys, cgs) -> ( - (* Can't translate types with regions for now *) - assert (regions = []); - (* Translate the type parameters *) - let t_tys = List.map translate tys in + | T.Adt (type_id, generics) -> ( + let t_generics = translate_fwd_generic_args type_infos generics in (* Eliminate boxes and simplify tuples *) match type_id with - | AdtId _ - | T.Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> - (* No general parametricity for now *) - assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); + | AdtId _ | T.Assumed (T.Array | T.Slice | T.Str) -> let type_id = translate_type_id type_id in - Adt (type_id, t_tys, cgs) + Adt (type_id, t_generics) | Tuple -> (* Note that if there is exactly one type, [mk_simpl_tuple_ty] is the identity *) - mk_simpl_tuple_ty t_tys + mk_simpl_tuple_ty t_generics.types | T.Assumed T.Box -> ( (* We eliminate boxes *) (* No general parametricity for now *) - assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys)); - match t_tys with + assert ( + not + (List.exists + (TypesUtils.ty_has_borrows type_infos) + generics.types)); + match t_generics.types with | [ bty ] -> bty | _ -> raise @@ -520,12 +559,40 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = | Never -> raise (Failure "Unreachable") | Literal lty -> Literal lty | Ref (_, rty, _) -> translate rty + | RawPtr (ty, rkind) -> + let mut = match rkind with Mut -> Mut | Shared -> Const in + let ty = translate ty in + let generics = { types = [ ty ]; const_generics = []; trait_refs = [] } in + Adt (Assumed (RawPtr mut), generics) + | TraitType (trait_ref, generics, type_name) -> + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + TraitType (trait_ref, generics, type_name) + | Arrow _ -> raise (Failure "TODO") + +and translate_fwd_generic_args (type_infos : TA.type_infos) + (generics : 'r T.generic_args) : generic_args = + translate_generic_args (translate_fwd_ty type_infos) generics + +and translate_fwd_trait_ref (type_infos : TA.type_infos) (tr : 'r T.trait_ref) : + trait_ref = + translate_trait_ref (translate_fwd_ty type_infos) tr + +and translate_fwd_trait_instance_id (type_infos : TA.type_infos) + (id : 'r T.trait_instance_id) : trait_instance_id = + translate_trait_instance_id (translate_fwd_ty type_infos) id (** Simply calls [translate_fwd_ty] *) let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty = let type_infos = ctx.type_context.type_infos in translate_fwd_ty type_infos ty +(** Simply calls [translate_fwd_generic_args] *) +let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : 'r T.generic_args) + : generic_args = + let type_infos = ctx.type_context.type_infos in + translate_fwd_generic_args type_infos generics + (** Translate a type, when some regions may have ended. We return an option, because the translated type may be empty. @@ -538,30 +605,40 @@ let rec translate_back_ty (type_infos : TA.type_infos) (* A small helper for "leave" types *) let wrap ty = if inside_mut then Some ty else None in match ty with - | T.Adt (type_id, _, tys, cgs) -> ( + | T.Adt (type_id, generics) -> ( match type_id with - | T.AdtId _ - | Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> - (* Don't accept ADTs (which are not tuples) with borrows for now *) - assert (not (TypesUtils.ty_has_borrows type_infos ty)); + | T.AdtId _ | Assumed (T.Array | T.Slice | T.Str) -> let type_id = translate_type_id type_id in if inside_mut then - let tys_t = List.filter_map translate tys in - Some (Adt (type_id, tys_t, cgs)) - else None + (* We do not want to filter anything, so we translate the generics + as "forward" types *) + let generics = translate_fwd_generic_args type_infos generics in + Some (Adt (type_id, generics)) + else + (* If not inside a mutable reference: check if at least one + of the generics contains a mutable reference (i.e., is not + translated to `None`. If yes, keep the whole type, and + translate all the generics as "forward" types (the backward + function will extract the proper information from the ADT value) + *) + let types = List.filter_map translate generics.types in + if types <> [] then + let generics = translate_fwd_generic_args type_infos generics in + Some (Adt (type_id, generics)) + else None | Assumed T.Box -> ( (* Don't accept ADTs (which are not tuples) with borrows for now *) assert (not (TypesUtils.ty_has_borrows type_infos ty)); (* Eliminate the box *) - match tys with + match generics.types with | [ bty ] -> translate bty | _ -> raise (Failure "Unreachable: boxes receive exactly one type parameter") ) | T.Tuple -> ( - (* Tuples can contain borrows (which we eliminated) *) - let tys_t = List.filter_map translate tys in + (* Tuples can contain borrows (which we eliminate) *) + let tys_t = List.filter_map translate generics.types in match tys_t with | [] -> None | _ -> @@ -582,6 +659,17 @@ let rec translate_back_ty (type_infos : TA.type_infos) if keep_region r then translate_back_ty type_infos keep_region inside_mut rty else None) + | RawPtr _ -> + (* TODO: not sure what to do here *) + None + | TraitType (trait_ref, generics, type_name) -> + assert (generics.regions = []); + (* Translate the trait ref and the generics as "forward" generics - + we do not want to filter any type *) + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + Some (TraitType (trait_ref, generics, type_name)) + | Arrow _ -> raise (Failure "TODO") (** Simply calls [translate_back_ty] *) let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) @@ -589,6 +677,80 @@ let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) let type_infos = ctx.type_context.type_infos in translate_back_ty type_infos keep_region inside_mut ty +let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx = + let const_generics = + T.ConstGenericVarId.Map.of_list + (List.map + (fun (cg : T.const_generic_var) -> + (cg.index, ctx_translate_fwd_ty ctx (T.Literal cg.ty))) + ctx.sg.generics.const_generics) + in + let env = VarId.Map.empty in + { + PureTypeCheck.type_decls = ctx.type_context.type_decls; + global_decls = ctx.global_context.llbc_global_decls; + env; + const_generics; + } + +let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit = + let ctx = mk_type_check_ctx ctx in + let _ = PureTypeCheck.check_typed_pattern ctx v in + () + +let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit = + if !Config.type_check_pure_code then + let ctx = mk_type_check_ctx ctx in + PureTypeCheck.check_texpression ctx e + +let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) + (id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref = + match id with + | FunId fun_id -> FunId fun_id + | TraitMethod (trait_ref, method_name, fun_decl_id) -> + let type_infos = ctx.type_context.type_infos in + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + TraitMethod (trait_ref, method_name, fun_decl_id) + +let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) + (args : texpression list) (ctx : bs_ctx) : bs_ctx = + let calls = ctx.calls in + assert (not (V.FunCallId.Map.mem call_id calls)); + let info = + { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty } + in + let calls = V.FunCallId.Map.add call_id info calls in + { ctx with calls } + +(** [back_args]: the *additional* list of inputs received by the backward function *) +let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) + (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) + : bs_ctx * fun_or_op_id = + (* Insert the abstraction in the call informations *) + let info = V.FunCallId.Map.find call_id ctx.calls in + assert (not (T.RegionGroupId.Map.mem back_id info.backwards)); + let backwards = + T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards + in + let info = { info with backwards } in + let calls = V.FunCallId.Map.add call_id info ctx.calls in + (* Insert the abstraction in the abstractions map *) + let abstractions = ctx.abstractions in + assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions)); + let abstractions = + V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions + in + (* Retrieve the fun_id *) + let fun_id = + match info.forward.call_id with + | S.Fun (fid, _) -> + let fid = translate_fun_id_or_trait_method_ref ctx fid in + Fun (FromLlbc (fid, None, Some back_id)) + | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + in + (* Update the context and return *) + ({ ctx with calls; abstractions }, fun_id) + (** List the ancestors of an abstraction *) let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) (call_id : V.FunCallId.id) : V.AbstractionId.id list = @@ -642,10 +804,10 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) : (** Small utility. *) let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (lid : V.LoopId.id option) + (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with - | A.Regular fid -> + | TraitMethod (_, _, fid) | FunId (Regular fid) -> let info = A.FunDeclId.Map.find fid fun_infos in let stateful_group = info.stateful in let stateful = @@ -658,10 +820,10 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) can_diverge = info.can_diverge; is_rec = info.is_rec || Option.is_some lid; } - | A.Assumed aid -> + | FunId (Assumed aid) -> assert (lid = None); { - can_fail = Assumed.assumed_can_fail aid; + can_fail = Assumed.assumed_fun_can_fail aid; stateful_group = false; stateful = false; can_diverge = false; @@ -673,12 +835,14 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) Note that the function also takes a list of names for the inputs, and computes, for every output for the backward functions, a corresponding name (outputs for backward functions come from borrows in the inputs - of the forward function) which we use as hints to generate pretty names. + of the forward function) which we use as hints to generate pretty names + in the extracted code. *) -let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (type_infos : TA.type_infos) (sg : A.fun_sig) - (input_names : string option list) (bid : T.RegionGroupId.id option) : - fun_sig_named_outputs = +let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) + (sg : A.fun_sig) (input_names : string option list) + (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = + let fun_infos = decls_ctx.fun_ctx.fun_infos in + let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) let gid, parents = match bid with @@ -689,7 +853,34 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) in (* Is the function stateful, and can it fail? *) let lid = None in - let effect_info = get_fun_effect_info fun_infos fun_id lid bid in + let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in + (* We need an evaluation context to normalize the types (to normalize the + associated types, etc. - for instance it may happen that the types + refer to the types associated to a trait ref, but where the trait ref + is a known impl). *) + (* Create the context *) + let ctx = + let region_groups = + List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy + in + let ctx = + InterpreterUtils.initialize_eval_context decls_ctx region_groups + sg.generics.types sg.generics.const_generics + in + (* Compute the normalization map for the *sty* types and add it to the context *) + AssociatedTypes.ctx_add_norm_trait_stypes_from_preds ctx + sg.preds.trait_type_constraints + in + + (* Normalize the signature *) + let sg = + let ({ A.inputs; output; _ } : A.fun_sig) = sg in + let norm = AssociatedTypes.ctx_normalize_sty ctx in + let inputs = List.map norm inputs in + let output = norm output in + { sg with A.inputs; output } + in + (* List the inputs for: * - the fuel * - the forward function @@ -806,9 +997,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) (* Wrap in a result type *) if effect_info.can_fail then mk_result_ty output else output in - (* Type/const generic parameters *) - let type_params = sg.type_params in - let const_generic_params = sg.const_generic_params in + (* Generic parameters *) + let generics = translate_generic_params sg.generics in (* Return *) let has_fuel = fuel <> [] in let num_fwd_inputs_no_state = List.length fwd_inputs in @@ -836,9 +1026,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) effect_info; } in - let sg = - { type_params; const_generic_params; inputs; output; doutputs; info } - in + let preds = translate_predicates sg.A.preds in + let sg = { generics; preds; inputs; output; doutputs; info } in { sg; output_names } let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = @@ -917,7 +1106,7 @@ let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = (** Peel boxes as long as the value is of the form [Box<T>] *) let rec unbox_typed_value (v : V.typed_value) : V.typed_value = match (v.value, v.ty) with - | V.Adt av, T.Adt (T.Assumed T.Box, _, _, _) -> ( + | V.Adt av, T.Adt (T.Assumed T.Box, _) -> ( match av.field_values with | [ bv ] -> unbox_typed_value bv | _ -> raise (Failure "Unreachable")) @@ -962,16 +1151,16 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx) let field_values = List.map translate av.field_values in (* Eliminate the tuple wrapper if it is a tuple with exactly one field *) match v.ty with - | T.Adt (T.Tuple, _, _, _) -> + | T.Adt (T.Tuple, _) -> assert (variant_id = None); mk_simpl_tuple_texpression field_values | _ -> - (* Retrieve the type, the translated type arguments and the - * const generic arguments from the translated type (simpler this way) *) - let adt_id, type_args, const_generic_args = ty_as_adt ty in + (* Retrieve the type and the translated generics from the translated + type (simpler this way) *) + let adt_id, generics = ty_as_adt ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args; const_generic_args } in + let qualif = { id = qualif_id; generics } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) field_values @@ -1038,11 +1227,9 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx) (* Translate the field values *) let field_values = List.filter_map translate adt_v.field_values in (* For now, only tuples can contain borrows *) - let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in + let adt_id, _ = TypesUtils.ty_as_adt av.ty in match adt_id with - | T.AdtId _ - | T.Assumed - (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> + | T.AdtId _ | T.Assumed (T.Box | T.Array | T.Slice | T.Str) -> assert (field_values = []); None | T.Tuple -> @@ -1185,11 +1372,9 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue) (* For now, only tuples can contain borrows - note that if we gave * something like a [&mut Vec] to a function, we give back the * vector value upon visiting the "abstraction borrow" node *) - let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in + let adt_id, _ = TypesUtils.ty_as_adt av.ty in match adt_id with - | T.AdtId _ - | T.Assumed - (T.Box | T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) -> + | T.AdtId _ | T.Assumed (T.Box | T.Array | T.Slice | T.Str) -> assert (field_values = []); (ctx, None) | T.Tuple -> @@ -1457,9 +1642,12 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = + log#ldebug + (lazy + ("translate_function_call:\n" + ^ ctx_egeneric_args_to_string ctx call.generics)); (* Translate the function call *) - let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in - let const_generic_args = call.const_generic_params in + let generics = ctx_translate_fwd_generic_args ctx call.generics in let args = let args = List.map (typed_value_to_texpression ctx call.ctx) call.args in let args_mplaces = List.map translate_opt_mplace call.args_places in @@ -1475,7 +1663,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) - let func = Fun (FromLlbc (fid, None, None)) in + let fid_t = translate_fun_id_or_trait_method_ref ctx fid in + let func = Fun (FromLlbc (fid_t, None, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = @@ -1525,18 +1714,20 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : in (ctx, Unop (Neg int_ty), effect_info, args, None) | _ -> raise (Failure "Unreachable")) - | S.Unop (E.Cast (src_ty, tgt_ty)) -> - (* Note that cast can fail *) - let effect_info = - { - can_fail = true; - stateful_group = false; - stateful = false; - can_diverge = false; - is_rec = false; - } - in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) + | S.Unop (E.Cast cast_kind) -> ( + match cast_kind with + | CastInteger (src_ty, tgt_ty) -> + (* Note that cast can fail *) + let effect_info = + { + can_fail = true; + stateful_group = false; + stateful = false; + can_diverge = false; + is_rec = false; + } + in + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)) | S.Binop binop -> ( match args with | [ arg0; arg1 ] -> @@ -1561,7 +1752,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | None -> dest | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] in - let func = { id = FunOrOp fun_id; type_args; const_generic_args } in + let func = { id = FunOrOp fun_id; generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty @@ -1665,9 +1856,11 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) (* Group the two lists *) let variables_values = List.combine given_back_variables consumed_values in (* Sanity check: the two lists match (same types) *) - List.iter - (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty)) - variables_values; + (* TODO: normalize the types *) + if !Config.type_check_pure_code then + List.iter + (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty)) + variables_values; (* Translate the next expression *) let next_e = translate_expression e ctx in (* Generate the assignemnts *) @@ -1692,8 +1885,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) let effect_info = get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id) in - let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in - let const_generic_args = call.const_generic_params in + let generics = ctx_translate_fwd_generic_args ctx call.generics in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs call_id in (* Retrieve the values consumed when we called the forward function and @@ -1741,34 +1933,35 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) - let _ = - let inst_sg = - get_instantiated_fun_sig fun_id (Some rg_id) type_args const_generic_args - ctx - in - log#ldebug - (lazy - ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" - ^ string_of_int (List.length inputs) - ^ "): " - ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) - ^ "\n- inst_sg.inputs (" - ^ string_of_int (List.length inst_sg.inputs) - ^ "): " - ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs))); - List.iter - (fun (x, ty) -> assert ((x : texpression).ty = ty)) - (List.combine inputs inst_sg.inputs); - log#ldebug - (lazy - ("\n- outputs: " - ^ string_of_int (List.length outputs) - ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.doutputs))); - List.iter - (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.doutputs) - in + (if (* TODO: normalize the types *) !Config.type_check_pure_code then + match fun_id with + | FunId fun_id -> + let inst_sg = + get_instantiated_fun_sig fun_id (Some rg_id) generics ctx + in + log#ldebug + (lazy + ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" + ^ string_of_int (List.length inputs) + ^ "): " + ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) + ^ "\n- inst_sg.inputs (" + ^ string_of_int (List.length inst_sg.inputs) + ^ "): " + ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs))); + List.iter + (fun (x, ty) -> assert ((x : texpression).ty = ty)) + (List.combine inputs inst_sg.inputs); + log#ldebug + (lazy + ("\n- outputs: " + ^ string_of_int (List.length outputs) + ^ "\n- expected outputs: " + ^ string_of_int (List.length inst_sg.doutputs))); + List.iter + (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) + (List.combine outputs inst_sg.doutputs) + | _ -> (* TODO: trait methods *) ()); (* Retrieve the function id, and register the function call in the context * if necessary *) let ctx, func = @@ -1788,7 +1981,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) if effect_info.can_fail then mk_result_ty output.ty else output.ty in let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp func; type_args; const_generic_args } in + let func = { id = FunOrOp func; generics } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: @@ -1905,14 +2098,13 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) (* Actually the same case as [SynthInput] *) translate_end_abstraction_synth_input ectx abs e ctx rg_id | V.LoopCall -> - let fun_id = A.Regular ctx.fun_decl.A.def_id in + let fun_id = E.Regular ctx.fun_decl.A.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id) - (Some rg_id) + get_fun_effect_info ctx.fun_context.fun_infos (FunId fun_id) + (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in - let type_args = loop_info.type_args in - let const_generic_args = loop_info.const_generic_args in + let generics = loop_info.generics in let fwd_inputs = Option.get loop_info.forward_inputs in (* Retrieve the additional backward inputs. Note that those are actually the backward inputs of the function we are synthesizing (and that we @@ -1960,8 +2152,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) if effect_info.can_fail then mk_result_ty output.ty else output.ty in let func_ty = mk_arrows input_tys ret_ty in - let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; type_args; const_generic_args } in + let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in + let func = { id = FunOrOp func; generics } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: @@ -2021,9 +2213,7 @@ and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = let ctx, var = fresh_var_for_symbolic_value sval ctx in let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in - let global_expr = - { id = Global gid; type_args = []; const_generic_args = [] } - in + let global_expr = { id = Global gid; generics = empty_generic_args } in (* We use translate_fwd_ty to translate the global type *) let ty = ctx_translate_fwd_ty ctx decl.ty in let gval = { e = Qualif global_expr; ty } in @@ -2037,11 +2227,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value) let v = typed_value_to_texpression ctx ectx v in let args = [ v ] in let func = - { - id = FunOrOp (Fun (Pure Assert)); - type_args = []; - const_generic_args = []; - } + { id = FunOrOp (Fun (Pure Assert)); generics = empty_generic_args } in let func_ty = mk_arrow (Literal Bool) mk_unit_ty in let func = { e = Qualif func; ty = func_ty } in @@ -2189,7 +2375,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) (branch : S.expression) (ctx : bs_ctx) : texpression = (* TODO: always introduce a match, and use micro-passes to turn the the match into a let? *) - let type_id, _, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in + let type_id, _ = TypesUtils.ty_as_adt sv.V.sv_ty in let ctx, vars = fresh_vars_for_symbolic_values svl ctx in let branch = translate_expression branch ctx in match type_id with @@ -2224,10 +2410,10 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) * field. * We use the [dest] variable in order not to have to recompute * the type of the result of the projection... *) - let adt_id, type_args, const_generic_args = ty_as_adt scrutinee.ty in + let adt_id, generics = ty_as_adt scrutinee.ty in let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression = let proj_kind = { adt_id; field_id } in - let qualif = { id = Proj proj_kind; type_args; const_generic_args } in + let qualif = { id = Proj proj_kind; generics } in let proj_e = Qualif qualif in let proj_ty = mk_arrow scrutinee.ty dest.ty in let proj = { e = proj_e; ty = proj_ty } in @@ -2259,17 +2445,12 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) (mk_typed_pattern_from_var var None) (mk_opt_mplace_texpression scrutinee_mplace scrutinee) branch - | T.Assumed (T.Vec | T.Array | T.Slice | T.Str) -> + | T.Assumed (T.Array | T.Slice | T.Str) -> (* We can't expand those values: we can access the fields only * through the functions provided by the API (note that we don't * know how to expand values like vectors or arrays, because they have a variable number * of fields!) *) raise (Failure "Attempt to expand a non-expandable value") - | T.Assumed Range -> raise (Failure "Unimplemented") - | T.Assumed T.Option -> - (* We shouldn't get there in the "one-branch" case: options have - * two variants *) - raise (Failure "Unreachable") and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) (sv : V.symbolic_value) (v : S.value_aggregate) (e : S.expression) @@ -2282,8 +2463,9 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) (* Translate the next expression *) let next_e = translate_expression e ctx in - (* Translate the value: there are two cases, depending on whether this - is a "regular" let-binding or an array aggregate. + (* Translate the value: there are several cases, depending on whether this + is a "regular" let-binding, an array aggregate, a const generic or + a trait associated constant. *) let v = match v with @@ -2298,6 +2480,14 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) { struct_id = Assumed Array; init = None; updates = values } in { e = StructUpdate su; ty = var.ty } + | ConstGenericValue cg_id -> { e = CVar cg_id; ty = var.ty } + | TraitConstValue (trait_ref, generics, const_name) -> + let type_infos = ctx.type_context.type_infos in + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + let qualif_id = TraitConst (trait_ref, generics, const_name) in + let qualif = { id = qualif_id; generics = empty_generic_args } in + { e = Qualif qualif; ty = var.ty } in (* Make the let-binding *) @@ -2368,9 +2558,9 @@ and translate_forward_end (ectx : C.eval_ctx) let org_args = args in (* Lookup the effect info for the loop function *) - let fid = A.Regular ctx.fun_decl.A.def_id in + let fid = E.Regular ctx.fun_decl.A.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid + get_fun_effect_info ctx.fun_context.fun_infos (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) @@ -2415,14 +2605,8 @@ and translate_forward_end (ectx : C.eval_ctx) let out_pat = mk_simpl_tuple_pattern out_pats in let loop_call = - let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in - let func = - { - id = FunOrOp fun_id; - type_args = loop_info.type_args; - const_generic_args = loop_info.const_generic_args; - } - in + let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in + let func = { id = FunOrOp fun_id; generics = loop_info.generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty @@ -2541,14 +2725,31 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Note that we will retrieve the input values later in the [ForwardEnd] (and will introduce the outputs at that moment, together with the actual - call to the loop forward function *) - let type_args = - List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) ctx.sg.type_params - in - let const_generic_args = - List.map - (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index) - ctx.sg.const_generic_params + call to the loop forward function) *) + let generics = + let { types; const_generics; trait_clauses } = ctx.sg.generics in + let types = + List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) types + in + let const_generics = + List.map + (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index) + const_generics + in + let trait_refs = + List.map + (fun (c : trait_clause) -> + let trait_decl_ref = + { trait_decl_id = c.trait_id; decl_generics = empty_generic_args } + in + { + trait_id = Clause c.clause_id; + generics = empty_generic_args; + trait_decl_ref; + }) + trait_clauses + in + { types; const_generics; trait_refs } in let loop_info = @@ -2556,8 +2757,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = loop_id; input_vars = inputs; input_svl = loop.input_svalues; - type_args; - const_generic_args; + generics; forward_inputs = None; forward_output_no_state_no_result = None; } @@ -2648,8 +2848,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) let func = { id = FunOrOp (Fun (Pure FuelEqZero)); - type_args = []; - const_generic_args = []; + generics = empty_generic_args; } in let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in @@ -2661,8 +2860,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) let func = { id = FunOrOp (Fun (Pure FuelDecrease)); - type_args = []; - const_generic_args = []; + generics = empty_generic_args; } in let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in @@ -2727,8 +2925,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None - bid + get_fun_effect_info ctx.fun_context.fun_infos (FunId (Regular def_id)) + None bid in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) @@ -2803,10 +3001,12 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = ^ "\n- signature.inputs: " ^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs) )); - assert ( - List.for_all - (fun (var, ty) -> (var : var).ty = ty) - (List.combine inputs signature.inputs)); + (* TODO: we need to normalize the types *) + if !Config.type_check_pure_code then + assert ( + List.for_all + (fun (var, ty) -> (var : var).ty = ty) + (List.combine inputs signature.inputs)); Some { inputs; inputs_lvs; body } in @@ -2821,6 +3021,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let def = { def_id; + kind = def.kind; num_loops; loop_id; back_id = bid; @@ -2853,8 +3054,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = - optional names for the outputs values (we derive them for the backward functions) *) -let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (type_infos : TA.type_infos) +let translate_fun_signatures (decls_ctx : C.decls_ctx) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdNotLoopMap.t = (* For every function, translate the signatures of: @@ -2865,17 +3065,14 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list = (* The forward function *) - let fwd_sg = - translate_fun_sig fun_infos fun_id type_infos sg input_names None - in + let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in let fwd_id = (fun_id, None) in (* The backward functions *) let back_sgs = List.map (fun (rg : T.region_var_group) -> let tsg = - translate_fun_sig fun_infos fun_id type_infos sg input_names - (Some rg.id) + translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) in let id = (fun_id, Some rg.id) in (id, tsg)) @@ -2891,3 +3088,94 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) List.fold_left (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m) RegularFunIdNotLoopMap.empty translated + +let translate_trait_decl (type_infos : TA.type_infos) + (trait_decl : A.trait_decl) : trait_decl = + let { + def_id; + name; + generics; + preds; + parent_clauses; + consts; + types; + required_methods; + provided_methods; + } : A.trait_decl = + trait_decl + in + let generics = translate_generic_params generics in + let preds = translate_predicates preds in + let parent_clauses = List.map translate_trait_clause parent_clauses in + let consts = + List.map + (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id))) + consts + in + let types = + List.map + (fun (name, (trait_clauses, ty)) -> + ( name, + ( List.map translate_trait_clause trait_clauses, + Option.map (translate_fwd_ty type_infos) ty ) )) + types + in + { + def_id; + name; + generics; + preds; + parent_clauses; + consts; + types; + required_methods; + provided_methods; + } + +let translate_trait_impl (type_infos : TA.type_infos) + (trait_impl : A.trait_impl) : trait_impl = + let { + A.def_id; + name; + impl_trait; + generics; + preds; + parent_trait_refs; + consts; + types; + required_methods; + provided_methods; + } = + trait_impl + in + let impl_trait = + translate_trait_decl_ref (translate_fwd_ty type_infos) impl_trait + in + let generics = translate_generic_params generics in + let preds = translate_predicates preds in + let parent_trait_refs = List.map translate_strait_ref parent_trait_refs in + let consts = + List.map + (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id))) + consts + in + let types = + List.map + (fun (name, (trait_refs, ty)) -> + ( name, + ( List.map (translate_fwd_trait_ref type_infos) trait_refs, + translate_fwd_ty type_infos ty ) )) + types + in + { + def_id; + name; + impl_trait; + generics; + preds; + parent_trait_refs; + consts; + types; + required_methods; + provided_methods; + } diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index 857fea97..9dd65c84 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -64,7 +64,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) assert (otherwise_see = None); (* Return *) ExpandInt (int_ty, branches, otherwise) - | T.Adt (_, _, _, _) -> + | T.Adt (_, _) -> (* Branching: it is necessarily an enumeration expansion *) let get_variant (see : V.symbolic_expansion option) : T.VariantId.id option * V.symbolic_value list = @@ -85,7 +85,9 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) match ls with | [ (Some see, exp) ] -> ExpandNoBranch (see, exp) | _ -> raise (Failure "Ill-formed borrow expansion")) - | T.TypeVar _ | T.Literal Char | Never -> + | T.TypeVar _ + | T.Literal Char + | Never | T.TraitType _ | T.Arrow _ | T.RawPtr _ -> raise (Failure "Ill-formed symbolic expansion") in Some (Expansion (place, sv, expansion)) @@ -97,10 +99,10 @@ let synthesize_symbolic_expansion_no_branching (sv : V.symbolic_value) synthesize_symbolic_expansion sv place [ Some see ] el let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) - (abstractions : V.AbstractionId.id list) (type_params : T.ety list) - (const_generic_params : T.const_generic list) (args : V.typed_value list) - (args_places : mplace option list) (dest : V.symbolic_value) - (dest_place : mplace option) (e : expression option) : expression option = + (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args) + (args : V.typed_value list) (args_places : mplace option list) + (dest : V.symbolic_value) (dest_place : mplace option) + (e : expression option) : expression option = Option.map (fun e -> let call = @@ -108,8 +110,7 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx) call_id; ctx; abstractions; - type_params; - const_generic_params; + generics; args; dest; args_places; @@ -123,28 +124,29 @@ let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value) (e : expression option) : expression option = Option.map (fun e -> EvalGlobal (gid, dest, e)) e -let synthesize_regular_function_call (fun_id : A.fun_id) +let synthesize_regular_function_call (fun_id : A.fun_id_or_trait_method_ref) (call_id : V.FunCallId.id) (ctx : Contexts.eval_ctx) - (abstractions : V.AbstractionId.id list) (type_params : T.ety list) - (const_generic_params : T.const_generic list) (args : V.typed_value list) - (args_places : mplace option list) (dest : V.symbolic_value) - (dest_place : mplace option) (e : expression option) : expression option = + (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args) + (args : V.typed_value list) (args_places : mplace option list) + (dest : V.symbolic_value) (dest_place : mplace option) + (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - ctx abstractions type_params const_generic_params args args_places dest - dest_place e + ctx abstractions generics args args_places dest dest_place e let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : E.unop) (arg : V.typed_value) (arg_place : mplace option) (dest : V.symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = - synthesize_function_call (Unop unop) ctx [] [] [] [ arg ] [ arg_place ] dest - dest_place e + let generics = TypesUtils.mk_empty_generic_args in + synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ] + dest dest_place e let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : E.binop) (arg0 : V.typed_value) (arg0_place : mplace option) (arg1 : V.typed_value) (arg1_place : mplace option) (dest : V.symbolic_value) (dest_place : mplace option) (e : expression option) : expression option = - synthesize_function_call (Binop binop) ctx [] [] [] [ arg0; arg1 ] + let generics = TypesUtils.mk_empty_generic_args in + synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ] [ arg0_place; arg1_place ] dest dest_place e let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : V.abs) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 70ef5e3d..a3d96023 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -5,6 +5,7 @@ module T = Types module A = LlbcAst module SA = SymbolicAst module Micro = PureMicroPasses +module C = Contexts open PureUtils open TranslateCore @@ -28,18 +29,12 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl) ("translate_function_to_symbolics: " ^ Print.fun_name_to_string fdef.A.name)); - let { type_context; fun_context; global_context } = trans_ctx in - let fun_context = { C.fun_decls = fun_context.fun_decls } in - match fdef.body with | None -> None | Some _ -> (* Evaluate *) let synthesize = true in - let inputs, symb = - evaluate_function_symbolic synthesize type_context fun_context - global_context fdef - in + let inputs, symb = evaluate_function_symbolic synthesize trans_ctx fdef in Some (inputs, Option.get symb) (** Translate a function, by generating its forward and backward translations. @@ -57,7 +52,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (lazy ("translate_function_to_pure: " ^ Print.fun_name_to_string fdef.A.name)); - let { type_context; fun_context; global_context } = trans_ctx in let def_id = fdef.def_id in (* Compute the symbolic ASTs, if the function is transparent *) @@ -67,7 +61,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Initialize the context *) let forward_sig = - RegularFunIdNotLoopMap.find (A.Regular def_id, None) fun_sigs + RegularFunIdNotLoopMap.find (E.Regular def_id, None) fun_sigs in let sv_to_var = V.SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in @@ -82,25 +76,25 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (List.filter_map (fun (tid, g) -> match g with Charon.GAst.NonRec _ -> None | Rec _ -> Some tid) - (T.TypeDeclId.Map.bindings trans_ctx.type_context.type_decls_groups)) + (T.TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups)) in let type_context = { - SymbolicToPure.type_infos = type_context.type_infos; - llbc_type_decls = type_context.type_decls; + SymbolicToPure.type_infos = trans_ctx.type_ctx.type_infos; + llbc_type_decls = trans_ctx.type_ctx.type_decls; type_decls = pure_type_decls; recursive_decls = recursive_type_decls; } in let fun_context = { - SymbolicToPure.llbc_fun_decls = fun_context.fun_decls; + SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls; fun_sigs; - fun_infos = fun_context.fun_infos; + fun_infos = trans_ctx.fun_ctx.fun_infos; } in let global_context = - { SymbolicToPure.llbc_global_decls = global_context.global_decls } + { SymbolicToPure.llbc_global_decls = trans_ctx.global_ctx.global_decls } in (* Compute the set of loops, and find better ids for them (starting at 0). @@ -148,6 +142,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx) type_context; fun_context; global_context; + trait_decls_ctx = trans_ctx.trait_decls_ctx.trait_decls; + trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls; fun_decl = fdef; forward_inputs = []; (* Empty for now *) @@ -204,7 +200,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Initialize the context - note that the ret_ty is not really * useful as we don't translate a body *) let backward_sg = - RegularFunIdNotLoopMap.find (A.Regular def_id, Some back_id) fun_sigs + RegularFunIdNotLoopMap.find (Regular def_id, Some back_id) fun_sigs in let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in @@ -215,7 +211,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) variables required by the backward function. *) let backward_sg = - RegularFunIdNotLoopMap.find (A.Regular def_id, Some back_id) fun_sigs + RegularFunIdNotLoopMap.find (Regular def_id, Some back_id) fun_sigs in (* We need to ignore the forward inputs, and the state input (if there is) *) let backward_inputs = @@ -274,21 +270,18 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Return *) (pure_forward, pure_backwards) +(* TODO: factor out the return type *) let translate_crate_to_pure (crate : A.crate) : - trans_ctx * Pure.type_decl list * (bool * pure_fun_translation) list = + trans_ctx + * Pure.type_decl list + * pure_fun_translation list + * Pure.trait_decl list + * Pure.trait_impl list = (* Debug *) log#ldebug (lazy "translate_crate_to_pure"); - (* Compute the type and function contexts *) - let type_context, fun_context, global_context = - compute_type_fun_global_contexts crate - in - let fun_infos = - FA.analyze_module crate fun_context.C.fun_decls - global_context.C.global_decls !Config.use_state - in - let fun_context = { fun_decls = fun_context.fun_decls; fun_infos } in - let trans_ctx = { type_context; fun_context; global_context } in + (* Compute the translation context *) + let trans_ctx = compute_contexts crate in (* Translate all the type definitions *) let type_decls = @@ -304,9 +297,11 @@ let translate_crate_to_pure (crate : A.crate) : (* Translate all the function *signatures* *) let assumed_sigs = List.map - (fun (id, sg, _, _) -> - (A.Assumed id, List.map (fun _ -> None) (sg : A.fun_sig).inputs, sg)) - Assumed.assumed_infos + (fun (info : Assumed.assumed_fun_info) -> + ( E.Assumed info.fun_id, + List.map (fun _ -> None) info.fun_sig.inputs, + info.fun_sig )) + Assumed.assumed_fun_infos in let local_sigs = List.map @@ -319,14 +314,11 @@ let translate_crate_to_pure (crate : A.crate) : (fun (v : A.var) -> v.name) (LlbcAstUtils.fun_body_get_input_vars body) in - (A.Regular fdef.def_id, input_names, fdef.signature)) + (E.Regular fdef.def_id, input_names, fdef.signature)) (A.FunDeclId.Map.values crate.functions) in let sigs = List.append assumed_sigs local_sigs in - let fun_sigs = - SymbolicToPure.translate_fun_signatures fun_context.fun_infos - type_context.type_infos sigs - in + let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in (* Translate all the *transparent* functions *) let pure_translations = @@ -335,28 +327,38 @@ let translate_crate_to_pure (crate : A.crate) : (A.FunDeclId.Map.values crate.functions) in + (* Translate the trait declarations *) + let type_infos = trans_ctx.type_ctx.type_infos in + let trait_decls = + List.map + (SymbolicToPure.translate_trait_decl type_infos) + (T.TraitDeclId.Map.values trans_ctx.trait_decls_ctx.trait_decls) + in + + (* Translate the trait implementations *) + let trait_impls = + List.map + (SymbolicToPure.translate_trait_impl type_infos) + (T.TraitImplId.Map.values trans_ctx.trait_impls_ctx.trait_impls) + in + (* Apply the micro-passes *) let pure_translations = Micro.apply_passes_to_pure_fun_translations trans_ctx pure_translations in (* Return *) - (trans_ctx, type_decls, pure_translations) - -(** Extraction context *) -type gen_ctx = { - crate : A.crate; - extract_ctx : ExtractBase.extraction_ctx; - trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; - trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; - functions_with_decreases_clause : PureUtils.FunLoopIdSet.t; -} + (trans_ctx, type_decls, pure_translations, trait_decls, trait_impls) + +type gen_ctx = ExtractBase.extraction_ctx type gen_config = { extract_types : bool; extract_decreases_clauses : bool; extract_template_decreases_clauses : bool; extract_fun_decls : bool; + extract_trait_decls : bool; + extract_trait_impls : bool; extract_transparent : bool; (** If [true], extract the transparent declarations, otherwise ignore. *) extract_opaque : bool; @@ -383,21 +385,23 @@ type gen_config = { test_trans_unit_functions : bool; } -(** Returns the pair: (has opaque type decls, has opaque fun decls) *) -let module_has_opaque_decls (ctx : gen_ctx) : bool * bool = - let has_opaque_types = - Pure.TypeDeclId.Map.exists - (fun _ (d : Pure.type_decl) -> - match d.kind with Opaque -> true | _ -> false) - ctx.trans_types - in - let has_opaque_funs = - A.FunDeclId.Map.exists - (fun _ ((_, ((t_fwd, _), _)) : bool * pure_fun_translation) -> - Option.is_none t_fwd.body) - ctx.trans_funs +(** Returns the pair: (has opaque type decls, has opaque fun decls). + + [filter_assumed]: if [true], do not consider as opaque the external definitions + that we will map to definitions from the standard library. + *) +let crate_has_opaque_non_builtin_decls (ctx : gen_ctx) (filter_assumed : bool) : + bool * bool = + let types, funs = + LlbcAstUtils.crate_get_opaque_non_builtin_decls ctx.crate filter_assumed in - (has_opaque_types, has_opaque_funs) + log#ldebug + (lazy + ("Opaque decls:" ^ "\n- types:\n" + ^ String.concat ",\n" (List.map T.show_type_decl types) + ^ "\n- functions:\n" + ^ String.concat ",\n" (List.map A.show_fun_decl funs))); + (types <> [], funs <> []) (** Export a type declaration. @@ -423,15 +427,19 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (true, kind) in (* Extract, if the config instructs to do so (depending on whether the type - * is opaque or not) *) - if + is opaque or not). Remark: we don't check if the definitions are builtin + here but in the function [export_types_group]: the reason is that if one + definition in the group is builtin, then we must check that all the + definitions are marked builtin *) + let extract = (is_opaque && config.extract_opaque) || ((not is_opaque) && config.extract_transparent) - then ( + in + if extract then ( if extract_decl then - Extract.extract_type_decl ctx.extract_ctx fmt type_decl_group kind def; + Extract.extract_type_decl ctx fmt type_decl_group kind def; if extract_extra_info then - Extract.extract_type_decl_extra_info ctx.extract_ctx fmt kind def) + Extract.extract_type_decl_extra_info ctx fmt kind def) (** Export a group of types. @@ -462,41 +470,58 @@ let export_types_group (fmt : Format.formatter) (config : gen_config) List.map (fun id -> Pure.TypeDeclId.Map.find id ctx.trans_types) ids in - (* Extract the type declarations. - - Because some declaration groups are delimited, we wrap the declarations - between [{start,end}_type_decl_group]. + (* Check if the definition are builtin - if yes they must be ignored. + Note that if one definition in the group is builtin, then all the + definitions must be builtin *) + let builtin = + let open ExtractBuiltin in + let types_map = builtin_types_map () in + List.map + (fun (def : Pure.type_decl) -> + let sname = name_to_simple_name def.name in + SimpleNameMap.find_opt sname types_map <> None) + defs + in - Ex.: - ==== - When targeting HOL4, the calls to [{start,end}_type_decl_group] would generate - the [Datatype] and [End] delimiters in the snippet of code below: + if List.exists (fun b -> b) builtin then + (* Sanity check *) + assert (List.for_all (fun b -> b) builtin) + else ( + (* Extract the type declarations. + + Because some declaration groups are delimited, we wrap the declarations + between [{start,end}_type_decl_group]. + + Ex.: + ==== + When targeting HOL4, the calls to [{start,end}_type_decl_group] would generate + the [Datatype] and [End] delimiters in the snippet of code below: + + {[ + Datatype: + tree = + TLeaf 'a + | TNode node ; + + node = + Node (tree list) + End + ]} + *) + Extract.start_type_decl_group ctx fmt is_rec defs; + List.iteri + (fun i def -> + let kind = kind_from_index i in + export_type_decl kind def) + defs; + Extract.end_type_decl_group fmt is_rec defs; - {[ - Datatype: - tree = - TLeaf 'a - | TNode node ; - - node = - Node (tree list) - End - ]} - *) - Extract.start_type_decl_group ctx.extract_ctx fmt is_rec defs; - List.iteri - (fun i def -> - let kind = kind_from_index i in - export_type_decl kind def) - defs; - Extract.end_type_decl_group fmt is_rec defs; - - (* Export the extra information (ex.: [Arguments] instructions in Coq) *) - List.iteri - (fun i def -> - let kind = kind_from_index i in - export_type_extra_info kind def) - defs + (* Export the extra information (ex.: [Arguments] instructions in Coq) *) + List.iteri + (fun i def -> + let kind = kind_from_index i in + export_type_extra_info kind def) + defs) (** Export a global declaration. @@ -504,26 +529,34 @@ let export_types_group (fmt : Format.formatter) (config : gen_config) *) let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (id : A.GlobalDeclId.id) : unit = - let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in + let global_decls = ctx.trans_ctx.global_ctx.global_decls in let global = A.GlobalDeclId.Map.find id global_decls in - let _, ((body, loop_fwds), body_backs) = - A.FunDeclId.Map.find global.body_id ctx.trans_funs - in - assert (body_backs = []); - assert (loop_fwds = []); + let trans = A.FunDeclId.Map.find global.body_id ctx.trans_funs in + assert (trans.fwd.loops = []); + assert (trans.backs = []); + let body = trans.fwd.f in let is_opaque = Option.is_none body.Pure.body in - if + (* Check if we extract the global *) + let extract = config.extract_globals && (((not is_opaque) && config.extract_transparent) || (is_opaque && config.extract_opaque)) - then + in + (* Check if it is a builtin global - if yes, we ignore it because we + map the definition to one in the standard library *) + let open ExtractBuiltin in + let sname = name_to_simple_name global.name in + let extract = + extract && SimpleNameMap.find_opt sname builtin_globals_map = None + in + if extract then (* We don't wrap global declaration groups between calls to functions [{start, end}_global_decl_group] (which don't exist): global declaration groups are always singletons, so the [extract_global_decl] function takes care of generating the delimiters. *) - Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface + Extract.extract_global_decl ctx fmt global body config.interface (** Utility. @@ -604,14 +637,13 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config) then Some (fun () -> - Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause - def) + Extract.extract_fun_decl ctx fmt kind has_decr_clause def) else None) decls in let extract_defs = List.filter_map (fun x -> x) extract_defs in if extract_defs <> [] then ( - Extract.start_fun_decl_group ctx.extract_ctx fmt is_rec decls; + Extract.start_fun_decl_group ctx fmt is_rec decls; List.iter (fun f -> f ()) extract_defs; Extract.end_fun_decl_group fmt is_rec decls) @@ -621,82 +653,137 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config) check if the forward and backward functions are mutually recursive. *) let export_functions_group (fmt : Format.formatter) (config : gen_config) - (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit = - (* Utility to check a function has a decrease clause *) - let has_decreases_clause (def : Pure.fun_decl) : bool = - PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) - ctx.functions_with_decreases_clause + (ctx : gen_ctx) (pure_ls : pure_fun_translation list) : unit = + (* Check if the definition are builtin - if yes they must be ignored. + Note that if one definition in the group is builtin, then all the + definitions must be builtin *) + let builtin = + let open ExtractBuiltin in + let funs_map = builtin_funs_map () in + List.map + (fun (trans : pure_fun_translation) -> + let sname = name_to_simple_name trans.fwd.f.basename in + SimpleNameMap.find_opt sname funs_map <> None) + pure_ls in - (* Extract the decrease clauses template bodies *) - if config.extract_template_decreases_clauses then - List.iter - (fun (_, ((fwd, loop_fwds), _)) -> - (* We only generate decreases clauses for the forward functions, because - the termination argument should only depend on the forward inputs. - The backward functions thus use the same decreases clauses as the - forward function. - - Rem.: we might filter backward functions in {!PureMicroPasses}, but - we don't remove forward functions. Instead, we remember if we should - filter those functions at extraction time with a boolean (see the - type of the [pure_ls] input parameter). - *) - let extract_decrease decl = - let has_decr_clause = has_decreases_clause decl in - if has_decr_clause then - match !Config.backend with - | Lean -> - Extract.extract_template_lean_termination_and_decreasing - ctx.extract_ctx fmt decl - | FStar -> - Extract.extract_template_fstar_decreases_clause ctx.extract_ctx - fmt decl - | Coq -> - raise (Failure "Coq doesn't have decreases/termination clauses") - | HOL4 -> - raise - (Failure "HOL4 doesn't have decreases/termination clauses") - in - extract_decrease fwd; - List.iter extract_decrease loop_fwds) - pure_ls; - - (* Concatenate the function definitions, filtering the useless forward - * functions. *) - let decls = - List.concat - (List.map - (fun (keep_fwd, ((fwd, fwd_loops), (back_ls : fun_and_loops list))) -> - let fwd = if keep_fwd then List.append fwd_loops [ fwd ] else [] in - let back : Pure.fun_decl list = - List.concat - (List.map - (fun (back, loop_backs) -> List.append loop_backs [ back ]) - back_ls) - in - List.append fwd back) - pure_ls) - in + if List.exists (fun b -> b) builtin then + (* Sanity check *) + assert (List.for_all (fun b -> b) builtin) + else + (* Utility to check a function has a decrease clause *) + let has_decreases_clause (def : Pure.fun_decl) : bool = + PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) + ctx.functions_with_decreases_clause + in - (* Extract the function definitions *) - (if config.extract_fun_decls then - (* Group the mutually recursive definitions *) - let subgroups = ReorderDecls.group_reorder_fun_decls decls in + (* Extract the decrease clauses template bodies *) + if config.extract_template_decreases_clauses then + List.iter + (fun { fwd; _ } -> + (* We only generate decreases clauses for the forward functions, because + the termination argument should only depend on the forward inputs. + The backward functions thus use the same decreases clauses as the + forward function. + + Rem.: we might filter backward functions in {!PureMicroPasses}, but + we don't remove forward functions. Instead, we remember if we should + filter those functions at extraction time with a boolean (see the + type of the [pure_ls] input parameter). + *) + let extract_decrease decl = + let has_decr_clause = has_decreases_clause decl in + if has_decr_clause then + match !Config.backend with + | Lean -> + Extract.extract_template_lean_termination_and_decreasing ctx + fmt decl + | FStar -> + Extract.extract_template_fstar_decreases_clause ctx fmt decl + | Coq -> + raise + (Failure "Coq doesn't have decreases/termination clauses") + | HOL4 -> + raise + (Failure "HOL4 doesn't have decreases/termination clauses") + in + extract_decrease fwd.f; + List.iter extract_decrease fwd.loops) + pure_ls; + + (* Concatenate the function definitions, filtering the useless forward + * functions. *) + let decls = + List.concat + (List.map + (fun { keep_fwd; fwd; backs } -> + let fwd = + if keep_fwd then List.append fwd.loops [ fwd.f ] else [] + in + let backs : Pure.fun_decl list = + List.concat + (List.map + (fun back -> List.append back.loops [ back.f ]) + backs) + in + List.append fwd backs) + pure_ls) + in - (* Extract the subgroups *) - let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit = - export_functions_group_scc fmt config ctx is_rec decls - in - List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups); - - (* Insert unit tests if necessary *) - if config.test_trans_unit_functions then - List.iter - (fun (keep_fwd, ((fwd, _), _)) -> - if keep_fwd then - Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd) - pure_ls + (* Extract the function definitions *) + (if config.extract_fun_decls then + (* Group the mutually recursive definitions *) + let subgroups = ReorderDecls.group_reorder_fun_decls decls in + + (* Extract the subgroups *) + let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit = + export_functions_group_scc fmt config ctx is_rec decls + in + List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups); + + (* Insert unit tests if necessary *) + if config.test_trans_unit_functions then + List.iter + (fun trans -> + if trans.keep_fwd then + Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f) + pure_ls + +(** Export a trait declaration. *) +let export_trait_decl (fmt : Format.formatter) (_config : gen_config) + (ctx : gen_ctx) (trait_decl_id : Pure.trait_decl_id) (extract_decl : bool) + (extract_extra_info : bool) : unit = + let trait_decl = T.TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls in + (* Check if the trait declaration is builtin, in which case we ignore it *) + let open ExtractBuiltin in + let sname = name_to_simple_name trait_decl.name in + if SimpleNameMap.find_opt sname (builtin_trait_decls_map ()) = None then ( + let ctx = { ctx with trait_decl_id = Some trait_decl.def_id } in + if extract_decl then Extract.extract_trait_decl ctx fmt trait_decl; + if extract_extra_info then + Extract.extract_trait_decl_extra_info ctx fmt trait_decl) + else () + +(** Export a trait implementation. *) +let export_trait_impl (fmt : Format.formatter) (_config : gen_config) + (ctx : gen_ctx) (trait_impl_id : Pure.trait_impl_id) : unit = + (* Lookup the definition *) + let trait_impl = T.TraitImplId.Map.find trait_impl_id ctx.trans_trait_impls in + let trait_decl = + Pure.TraitDeclId.Map.find trait_impl.impl_trait.trait_decl_id + ctx.trans_trait_decls + in + (* Check if the trait implementation is builtin *) + let builtin_info = + let open ExtractBuiltin in + let type_sname = name_to_simple_name trait_impl.name in + let trait_sname = name_to_simple_name trait_decl.name in + SimpleNamePairMap.find_opt (type_sname, trait_sname) + (builtin_trait_impls_map ()) + in + match builtin_info with + | None -> Extract.extract_trait_impl ctx fmt trait_impl + | Some _ -> () (** A generic utility to generate the extracted definitions: as we may want to split the definitions between different files (or not), we can control @@ -712,12 +799,19 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) let export_functions_group = export_functions_group fmt config ctx in let export_global = export_global fmt config ctx in let export_types_group = export_types_group fmt config ctx in + let export_trait_decl_group id = + export_trait_decl fmt config ctx id true false + in + let export_trait_decl_group_extra_info id = + export_trait_decl fmt config ctx id false true + in + let export_trait_impl = export_trait_impl fmt config ctx in let export_state_type () : unit = let kind = if config.interface then ExtractBase.Declared else ExtractBase.Assumed in - Extract.extract_state_type fmt ctx.extract_ctx kind + Extract.extract_state_type fmt ctx kind in let export_decl_group (dg : A.declaration_group) : unit = @@ -725,11 +819,18 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) | Type (NonRec id) -> if config.extract_types then export_types_group false [ id ] | Type (Rec ids) -> if config.extract_types then export_types_group true ids - | Fun (NonRec id) -> + | Fun (NonRec id) -> ( (* Lookup *) let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in - (* Translate *) - export_functions_group [ pure_fun ] + (* Special case: we skip trait method *declarations* (we will + extract their type directly in the records we generate for + the trait declarations themselves, there is no point in having + separate type definitions) *) + match pure_fun.fwd.f.Pure.kind with + | TraitMethodDecl _ -> () + | _ -> + (* Translate *) + export_functions_group [ pure_fun ]) | Fun (Rec ids) -> (* General case of mutually recursive functions *) (* Lookup *) @@ -739,11 +840,19 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) (* Translate *) export_functions_group pure_funs | Global id -> export_global id + | TraitDecl id -> + (* TODO: update to extract groups *) + if config.extract_trait_decls && config.extract_transparent then ( + export_trait_decl_group id; + export_trait_decl_group_extra_info id) + | TraitImpl id -> + if config.extract_trait_impls && config.extract_transparent then + export_trait_impl id in (* If we need to export the state type: we try to export it after we defined * the type definitions, because if the user wants to define a model for the - * type, he might want to reuse those in the state type. + * type, they might want to reuse those in the state type. * More specifically: if we extract functions in the same file as the type, * we have no choice but to define the state type before the functions, * because they may reuse this state type: in this case, we define/declare @@ -752,37 +861,10 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) if config.extract_state_type && config.extract_fun_decls then export_state_type (); - (* Obsolete: (TODO: remove) For Lean we parameterize the entire development by a section - variable called opaque_defs, of type OpaqueDefs. The code below emits the type - definition for OpaqueDefs, which is a structure, in which each field is one of the - functions marked as Opaque. We emit the `structure ...` bit here, then rely on - `extract_fun_decl` to be aware of this, and skip the keyword (e.g. "axiom" or "val") - so as to generate valid syntax for records. - - We also generate such a structure only if there actually are opaque definitions. *) - let wrap_in_sig = - config.extract_opaque && config.extract_fun_decls - && !Config.wrap_opaque_in_sig - && - let _, opaque_funs = module_has_opaque_decls ctx in - opaque_funs - in - if wrap_in_sig then ( - (* We change the name of the structure depending on whether we *only* - extract opaque definitions, or if we extract all definitions *) - let struct_name = - if config.extract_transparent then "Definitions" else "OpaqueDefs" - in - Format.pp_print_break fmt 0 0; - Format.pp_open_vbox fmt ctx.extract_ctx.indent_incr; - Format.pp_print_string fmt ("structure " ^ struct_name ^ " where"); - Format.pp_print_break fmt 0 0); List.iter export_decl_group ctx.crate.declarations; if config.extract_state_type && not config.extract_fun_decls then - export_state_type (); - - if wrap_in_sig then Format.pp_close_box fmt () + export_state_type () type extract_file_info = { filename : string; @@ -904,7 +986,9 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (fi : extract_file_info) let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : unit = (* Translate the module to the pure AST *) - let trans_ctx, trans_types, trans_funs = translate_crate_to_pure crate in + let trans_ctx, trans_types, trans_funs, trans_trait_decls, trans_trait_impls = + translate_crate_to_pure crate + in (* Initialize the extraction context - for now we extract only to F*. * We initialize the names map by registering the keywords used in the @@ -916,41 +1000,27 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : in (* Initialize the names map (we insert the names of the "primitives" declarations, and insert the names of the local declarations later) *) - let mk_formatter_and_names_map = Extract.mk_formatter_and_names_map in - let fmt, names_map = - mk_formatter_and_names_map trans_ctx crate.name + let fmt, names_maps = + Extract.mk_formatter_and_names_maps trans_ctx crate.name variant_concatenate_type_name in - (* Put everything in the context *) - let ctx = - { - ExtractBase.trans_ctx; - names_map; - unsafe_names_map = { id_to_name = ExtractBase.IdMap.empty }; - fmt; - indent_incr = 2; - use_opaque_pre = !Config.split_files; - use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses; - fun_name_info = PureUtils.RegularFunIdMap.empty; - } - in (* We need to compute which functions are recursive, in order to know * whether we should generate a decrease clause or not. *) let rec_functions = List.map - (fun (_, ((fwd, loop_fwds), _)) -> - let fwd = - if fwd.Pure.signature.info.effect_info.is_rec then - [ (fwd.def_id, None) ] + (fun { fwd; _ } -> + let fwd_f = + if fwd.f.Pure.signature.info.effect_info.is_rec then + [ (fwd.f.def_id, None) ] else [] in let loop_fwds = List.map (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ]) - loop_fwds + fwd.loops in - fwd :: loop_fwds) + fwd_f :: loop_fwds) trans_funs in let rec_functions : PureUtils.fun_loop_id list = @@ -958,22 +1028,70 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : in let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in - (* Register unique names for all the top-level types, globals and functions. + (* Put the translated definitions in maps *) + let trans_types = + Pure.TypeDeclId.Map.of_list + (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) + in + let trans_funs : pure_fun_translation A.FunDeclId.Map.t = + A.FunDeclId.Map.of_list + (List.map + (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans)) + trans_funs) + in + + (* Put everything in the context *) + let ctx = + let trans_trait_decls = + T.TraitDeclId.Map.of_list + (List.map + (fun (d : Pure.trait_decl) -> (d.def_id, d)) + trans_trait_decls) + in + let trans_trait_impls = + T.TraitImplId.Map.of_list + (List.map + (fun (d : Pure.trait_impl) -> (d.def_id, d)) + trans_trait_impls) + in + { + ExtractBase.crate; + trans_ctx; + names_maps; + fmt; + indent_incr = 2; + use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses; + fun_name_info = PureUtils.RegularFunIdMap.empty; + trait_decl_id = None (* None by default *); + is_provided_method = false (* false by default *); + trans_trait_decls; + trans_trait_impls; + trans_types; + trans_funs; + functions_with_decreases_clause = rec_functions; + types_filter_type_args_map = Pure.TypeDeclId.Map.empty; + funs_filter_type_args_map = Pure.FunDeclId.Map.empty; + trait_impls_filter_type_args_map = Pure.TraitImplId.Map.empty; + } + in + + (* Register unique names for all the top-level types, globals, functions... * Note that the order in which we generate the names doesn't matter: * we just need to generate a mapping from identifier to name, and make * sure there are no name clashes. *) let ctx = List.fold_left (fun ctx def -> Extract.extract_type_decl_register_names ctx def) - ctx trans_types + ctx + (Pure.TypeDeclId.Map.values trans_types) in let ctx = List.fold_left - (fun ctx (keep_fwd, defs) -> + (fun ctx (trans : pure_fun_translation) -> (* If requested by the user, register termination measures and decreases proofs for all the recursive functions *) - let fwd_def = fst (fst defs) in + let fwd_def = trans.fwd.f in let gen_decr_clause (def : Pure.fun_decl) = !Config.extract_decreases_clauses && PureUtils.FunLoopIdSet.mem @@ -984,10 +1102,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : * those are handled later *) let is_global = fwd_def.Pure.is_global_decl_body in if is_global then ctx - else - Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause - defs) - ctx trans_funs + else Extract.extract_fun_decl_register_names ctx gen_decr_clause trans) + ctx + (A.FunDeclId.Map.values trans_funs) in let ctx = @@ -995,6 +1112,16 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (A.GlobalDeclId.Map.values crate.globals) in + let ctx = + List.fold_left Extract.extract_trait_decl_register_names ctx + trans_trait_decls + in + + let ctx = + List.fold_left Extract.extract_trait_impl_register_names ctx + trans_trait_impls + in + (* Open the output file *) (* First compute the filename by replacing the extension and converting the * case (rust module names are snake case) *) @@ -1023,19 +1150,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (namespace, crate_name, Filename.concat dest_dir crate_name) in - (* Put the translated definitions in maps *) - let trans_types = - Pure.TypeDeclId.Map.of_list - (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) - in - let trans_funs = - A.FunDeclId.Map.of_list - (List.map - (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> - ((fst fd).def_id, (keep_fwd, (fd, bdl)))) - trans_funs) - in - let mkdir_if dest_dir = if not (Sys.file_exists dest_dir) then ( log#linfo (lazy ("Creating missing directory: " ^ dest_dir)); @@ -1091,16 +1205,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : in (* Extract the file(s) *) - let gen_ctx = - { - crate; - extract_ctx = ctx; - trans_types; - trans_funs; - functions_with_decreases_clause = rec_functions; - } - in - let module_delimiter = match !Config.backend with | FStar -> "." @@ -1136,6 +1240,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : extract_decreases_clauses = !Config.extract_decreases_clauses; extract_template_decreases_clauses = false; extract_fun_decls = false; + extract_trait_decls = false; + extract_trait_impls = false; extract_transparent = true; extract_opaque = false; extract_state_type = false; @@ -1147,7 +1253,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (* Check if there are opaque types and functions - in which case we need * to split *) - let has_opaque_types, has_opaque_funs = module_has_opaque_decls gen_ctx in + let has_opaque_types, has_opaque_funs = + crate_has_opaque_non_builtin_decls ctx true + in let has_opaque_types = has_opaque_types || !Config.use_state in (* Extract the types *) @@ -1168,6 +1276,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : { base_gen_config with extract_types = true; + extract_trait_decls = true; extract_opaque = true; extract_state_type = !Config.use_state; interface = has_opaque_types; @@ -1186,7 +1295,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - extract_file types_config gen_ctx file_info; + extract_file types_config ctx file_info; (* Extract the template clauses *) (if needs_clauses_module && !Config.extract_template_decreases_clauses then @@ -1214,9 +1323,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - extract_file template_clauses_config gen_ctx file_info); + extract_file template_clauses_config ctx file_info); - (* Extract the opaque functions, if needed *) + (* Extract the opaque declarations, if needed *) let opaque_funs_module = if has_opaque_funs then ( (* In the case of Lean we generate a template file *) @@ -1244,17 +1353,13 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : { base_gen_config with extract_fun_decls = true; + extract_trait_impls = true; + extract_globals = true; extract_transparent = false; extract_opaque = true; interface = true; } in - let gen_ctx = - { - gen_ctx with - extract_ctx = { gen_ctx.extract_ctx with use_opaque_pre = false }; - } - in let file_info = { filename = opaque_filename; @@ -1268,7 +1373,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = [ types_module ]; } in - extract_file opaque_config gen_ctx file_info; + extract_file opaque_config ctx file_info; (* Return the additional dependencies *) [ opaque_imported_module ]) else [] @@ -1281,6 +1386,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : { base_gen_config with extract_fun_decls = true; + extract_trait_impls = true; extract_globals = true; test_trans_unit_functions = !Config.test_trans_unit_functions; } @@ -1307,7 +1413,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : [ types_module ] @ opaque_funs_module @ clauses_module; } in - extract_file fun_config gen_ctx file_info) + extract_file fun_config ctx file_info) else let gen_config = { @@ -1316,6 +1422,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : extract_template_decreases_clauses = !Config.extract_template_decreases_clauses; extract_fun_decls = true; + extract_trait_decls = true; + extract_trait_impls = true; extract_transparent = true; extract_opaque = true; extract_state_type = !Config.use_state; @@ -1337,7 +1445,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - extract_file gen_config gen_ctx file_info); + extract_file gen_config ctx file_info); (* Generate the build file *) match !Config.backend with diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml index ba5e237b..3427fd43 100644 --- a/compiler/TranslateCore.ml +++ b/compiler/TranslateCore.ml @@ -10,64 +10,69 @@ module FA = FunsAnalysis (** The local logger *) let log = L.translate_log -type type_context = C.type_context [@@deriving show] - -type fun_context = { - fun_decls : A.fun_decl A.FunDeclId.Map.t; - fun_infos : FA.fun_info A.FunDeclId.Map.t; -} -[@@deriving show] +type trans_ctx = C.decls_ctx [@@deriving show] +type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list } +type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list -type global_context = C.global_context [@@deriving show] +type pure_fun_translation = { + keep_fwd : bool; + (** Should we extract the forward function? -type trans_ctx = { - type_context : type_context; - fun_context : fun_context; - global_context : global_context; + If the forward function returns `()` and there is exactly one + backward function, we may merge the forward into the backward + function and thus don't extract the forward function)? + *) + fwd : fun_and_loops; + backs : fun_and_loops list; } -type fun_and_loops = Pure.fun_decl * Pure.fun_decl list -type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list -type pure_fun_translation = fun_and_loops * fun_and_loops list +let trans_ctx_to_type_formatter (ctx : trans_ctx) + (type_params : Pure.type_var list) + (const_generic_params : Pure.const_generic_var list) : + PrintPure.type_formatter = + let type_decls = ctx.type_ctx.type_decls in + let global_decls = ctx.global_ctx.global_decls in + let trait_decls = ctx.trait_decls_ctx.trait_decls in + let trait_impls = ctx.trait_impls_ctx.trait_impls in + PrintPure.mk_type_formatter type_decls global_decls trait_decls trait_impls + type_params const_generic_params let type_decl_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string = - let type_params = def.type_params in - let cg_params = def.const_generic_params in - let type_decls = ctx.type_context.type_decls in - let global_decls = ctx.global_context.global_decls in + let generics = def.generics in let fmt = - PrintPure.mk_type_formatter type_decls global_decls type_params cg_params + trans_ctx_to_type_formatter ctx generics.types generics.const_generics in PrintPure.type_decl_to_string fmt def let type_id_to_string (ctx : trans_ctx) (id : Pure.TypeDeclId.id) : string = Print.fun_name_to_string - (Pure.TypeDeclId.Map.find id ctx.type_context.type_decls).name + (Pure.TypeDeclId.Map.find id ctx.type_ctx.type_decls).name + +let trans_ctx_to_ast_formatter (ctx : trans_ctx) + (type_params : Pure.type_var list) + (const_generic_params : Pure.const_generic_var list) : + PrintPure.ast_formatter = + let type_decls = ctx.type_ctx.type_decls in + let fun_decls = ctx.fun_ctx.fun_decls in + let global_decls = ctx.global_ctx.global_decls in + let trait_decls = ctx.trait_decls_ctx.trait_decls in + let trait_impls = ctx.trait_impls_ctx.trait_impls in + PrintPure.mk_ast_formatter type_decls fun_decls global_decls trait_decls + trait_impls type_params const_generic_params let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string = - let type_params = sg.type_params in - let cg_params = sg.const_generic_params in - let type_decls = ctx.type_context.type_decls in - let fun_decls = ctx.fun_context.fun_decls in - let global_decls = ctx.global_context.global_decls in + let generics = sg.generics in let fmt = - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + trans_ctx_to_ast_formatter ctx generics.types generics.const_generics in PrintPure.fun_sig_to_string fmt sg let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string = - let type_params = def.signature.type_params in - let cg_params = def.signature.const_generic_params in - let type_decls = ctx.type_context.type_decls in - let fun_decls = ctx.fun_context.fun_decls in - let global_decls = ctx.global_context.global_decls in + let generics = def.signature.generics in let fmt = - PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params - cg_params + trans_ctx_to_ast_formatter ctx generics.types generics.const_generics in PrintPure.fun_decl_to_string fmt def let fun_decl_id_to_string (ctx : trans_ctx) (id : A.FunDeclId.id) : string = - Print.fun_name_to_string - (A.FunDeclId.Map.find id ctx.fun_context.fun_decls).name + Print.fun_name_to_string (A.FunDeclId.Map.find id ctx.fun_ctx.fun_decls).name diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml index 925f6d39..38d350b1 100644 --- a/compiler/TypesAnalysis.ml +++ b/compiler/TypesAnalysis.ml @@ -14,11 +14,10 @@ type expl_info = subtype_info [@@deriving show] type type_borrows_info = { contains_static : bool; - (** Does the type (transitively) contains a static borrow? *) - contains_borrow : bool; - (** Does the type (transitively) contains a borrow? *) + (** Does the type (transitively) contain a static borrow? *) + contains_borrow : bool; (** Does the type (transitively) contain a borrow? *) contains_nested_borrows : bool; - (** Does the type (transitively) contains nested borrows? *) + (** Does the type (transitively) contain nested borrows? *) contains_borrow_under_mut : bool; } [@@deriving show] @@ -61,7 +60,7 @@ let initialize_g_type_info (param_infos : 'p) : 'p g_type_info = let initialize_type_decl_info (def : type_decl) : type_decl_info = let param_info = { under_borrow = false; under_mut_borrow = false } in - let param_infos = List.map (fun _ -> param_info) def.type_params in + let param_infos = List.map (fun _ -> param_info) def.generics.types in initialize_g_type_info param_infos let type_decl_info_to_partial_type_info (info : type_decl_info) : @@ -122,7 +121,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) let rec analyze (expl_info : expl_info) (ty_info : partial_type_info) (ty : 'r ty) : partial_type_info = match ty with - | Literal _ | Never -> ty_info + | Literal _ | Never | TraitType _ -> ty_info | TypeVar var_id -> ( (* Update the information for the proper parameter, if necessary *) match ty_info.param_infos with @@ -169,22 +168,21 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) in (* Continue exploring *) analyze expl_info ty_info rty - | Adt - ( (Tuple | Assumed (Box | Vec | Option | Slice | Array | Str | Range)), - _, - tys, - _ ) -> + | RawPtr (rty, _) -> + (* TODO: not sure what to do here *) + analyze expl_info ty_info rty + | Adt ((Tuple | Assumed (Box | Slice | Array | Str)), generics) -> (* Nothing to update: just explore the type parameters *) List.fold_left (fun ty_info ty -> analyze expl_info ty_info ty) - ty_info tys - | Adt (AdtId adt_id, regions, tys, _cgs) -> + ty_info generics.types + | Adt (AdtId adt_id, generics) -> (* Lookup the information for this type definition *) let adt_info = TypeDeclId.Map.find adt_id infos in (* Update the type info with the information from the adt *) let ty_info = update_ty_info ty_info adt_info.borrows_info in (* Check if 'static appears in the region parameters *) - let found_static = List.exists r_is_static regions in + let found_static = List.exists r_is_static generics.regions in let borrows_info = ty_info.borrows_info in let borrows_info = { @@ -196,7 +194,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) let ty_info = { ty_info with borrows_info } in (* For every instantiated type parameter: update the exploration info * then explore the type *) - let params_tys = List.combine adt_info.param_infos tys in + let params_tys = List.combine adt_info.param_infos generics.types in let ty_info = List.fold_left (fun ty_info (param_info, ty) -> @@ -235,6 +233,14 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) in (* Return *) ty_info + | Arrow (inputs, output) -> + (* Just dive into the arrow *) + let ty_info = + List.fold_left + (fun ty_info ty -> analyze expl_info ty_info ty) + ty_info inputs + in + analyze expl_info ty_info output in (* Explore *) analyze expl_info_init ty_info ty diff --git a/compiler/Values.ml b/compiler/Values.ml index d884c319..de27e7a9 100644 --- a/compiler/Values.ml +++ b/compiler/Values.ml @@ -52,6 +52,10 @@ type sv_kind = (** The result of a loop join (when computing loop fixed points) *) | Aggregate (** A symbolic value we introduce in place of an aggregate value *) + | ConstGeneric + (** A symbolic value we introduce when using a const generic as a value *) + | TraitConst + (** A symbolic value we introduce when evaluating a trait associated constant *) [@@deriving show, ord] (** Ancestor for {!symbolic_value} iter visitor *) diff --git a/compiler/dune b/compiler/dune index 6785cad4..648c7325 100644 --- a/compiler/dune +++ b/compiler/dune @@ -12,6 +12,7 @@ (pps ppx_deriving.show ppx_deriving.ord visitors.ppx)) (libraries charon core_unix unionFind ocamlgraph) (modules + AssociatedTypes Assumed Collections Config @@ -22,6 +23,8 @@ ExpressionsUtils Extract ExtractBase + ExtractBuiltin + ExtractTypes FunsAnalysis Identifiers InterpreterBorrowsCore @@ -90,4 +93,4 @@ -g ;-dsource -warn-error - -5-8-9-11-14-33-20-21-26-27-39))) + -5@8-9-11-14-33-20-21-26-27-39))) |