summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/AssociatedTypes.ml90
-rw-r--r--compiler/Config.ml8
-rw-r--r--compiler/Print.ml22
-rw-r--r--compiler/SymbolicToPure.ml77
4 files changed, 157 insertions, 40 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml
index c4a9538d..bce0fb11 100644
--- a/compiler/AssociatedTypes.ml
+++ b/compiler/AssociatedTypes.ml
@@ -15,6 +15,7 @@ module C = Contexts
module Subst = Substitute
module L = Logging
module UF = UnionFind
+module PA = Print.EvalCtxLlbcAst
(** The local logger *)
let log = L.associated_types_log
@@ -111,6 +112,10 @@ type 'r norm_ctx = {
get_ty_repr : 'r C.trait_type_ref -> 'r T.ty option;
convert_ety : T.ety -> 'r T.ty;
convert_etrait_ref : T.etrait_ref -> 'r T.trait_ref;
+ ty_to_string : 'r T.ty -> string;
+ trait_ref_to_string : 'r T.trait_ref -> string;
+ trait_instance_id_to_string : 'r T.trait_instance_id -> string;
+ pp_r : Format.formatter -> 'r -> unit;
}
(** Normalize a type by simplyfying the references to trait associated types
@@ -118,6 +123,7 @@ type 'r norm_ctx = {
enforced by local clauses (i.e., `where Trait1::T = Trait2::U`. *)
let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty =
fun ctx ty ->
+ log#ldebug (lazy ("ctx_normalize_ty: " ^ ctx.ty_to_string ty));
match ty with
| T.Adt (id, generics) -> Adt (id, ctx_normalize_generic_args ctx generics)
| TypeVar _ | Literal _ | Never -> ty
@@ -125,19 +131,56 @@ let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty =
let ty = ctx_normalize_ty ctx ty in
T.Ref (r, ty, rkind)
| TraitType (trait_ref, generics, type_name) -> (
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty: trait type: " ^ ctx.ty_to_string ty
+ ^ "\n- trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref: "
+ ^ T.show_trait_ref ctx.pp_r trait_ref));
(* Normalize and attempt to project the type from the trait ref *)
let trait_ref = ctx_normalize_trait_ref ctx trait_ref in
let generics = ctx_normalize_generic_args ctx generics in
let ty : 'r T.ty =
match trait_ref.trait_id with
- | T.TraitRef { T.trait_id = T.TraitImpl impl_id; generics; _ } ->
+ | T.TraitRef
+ { T.trait_id = T.TraitImpl impl_id; generics = ref_generics; _ } ->
+ assert (ref_generics = TypesUtils.mk_empty_generic_args);
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty: trait type: trait ref: "
+ ^ ctx.ty_to_string ty));
(* Lookup the implementation *)
let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in
(* Lookup the type *)
let ty = snd (List.assoc type_name trait_impl.types) in
(* Annoying: convert etype to an stype - TODO: hwo to avoid that? *)
let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in
- (* Substitute - annoying: we can't use *)
+ (* Substitute *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst =
+ Subst.make_subst_from_generics_no_regions trait_impl.generics
+ generics tr_self
+ in
+ let ty = Subst.ty_substitute subst ty in
+ (* Reconvert *)
+ let ty : 'r T.ty = ctx.convert_ety (Subst.erase_regions ty) in
+ (* Normalize *)
+ ctx_normalize_ty ctx ty
+ | T.TraitImpl impl_id ->
+ (* This happens. This doesn't come from the substituations
+ performed by Aeneas (the [TraitImpl] would be wrapped in a
+ [TraitRef] but from non-normalized traits translated from
+ the Rustc AST.
+ TODO: factor out with the branch above.
+ *)
+ (* Lookup the implementation *)
+ let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in
+ (* Lookup the type *)
+ let ty = snd (List.assoc type_name trait_impl.types) in
+ (* Annoying: convert etype to an stype - TODO: hwo to avoid that? *)
+ let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in
+ (* Substitute *)
let tr_self = T.UnknownTrait __FUNCTION__ in
let subst =
Subst.make_subst_from_generics_no_regions trait_impl.generics
@@ -149,6 +192,13 @@ let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty =
(* Normalize *)
ctx_normalize_ty ctx ty
| _ ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty: trait type: not a trait ref: "
+ ^ ctx.ty_to_string ty ^ "\n- trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref: "
+ ^ T.show_trait_ref ctx.pp_r trait_ref));
(* We can't project *)
assert (trait_instance_id_is_local_clause trait_ref.trait_id);
T.TraitType (trait_ref, generics, type_name)
@@ -307,11 +357,31 @@ and ctx_normalize_generic_args (ctx : 'r norm_ctx)
and ctx_normalize_trait_ref (ctx : 'r norm_ctx) (trait_ref : 'r T.trait_ref) :
'r T.trait_ref =
+ log#ldebug
+ (lazy
+ ("ctx_normalize_trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref:\n"
+ ^ T.show_trait_ref ctx.pp_r trait_ref));
let { T.trait_id; generics; trait_decl_ref } = trait_ref in
- let trait_id, _ = ctx_normalize_trait_instance_id ctx trait_id in
- let generics = ctx_normalize_generic_args ctx generics in
- let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in
- { T.trait_id; generics; trait_decl_ref }
+ (* Check if the id is an impl, otherwise normalize it *)
+ let trait_id, norm_trait_ref = ctx_normalize_trait_instance_id ctx trait_id in
+ match norm_trait_ref with
+ | None ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_trait_ref: no norm: "
+ ^ ctx.trait_instance_id_to_string trait_id));
+ let generics = ctx_normalize_generic_args ctx generics in
+ let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in
+ { T.trait_id; generics; trait_decl_ref }
+ | Some trait_ref ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_trait_ref: normalized to: "
+ ^ ctx.trait_ref_to_string trait_ref));
+ assert (generics = TypesUtils.mk_empty_generic_args);
+ trait_ref
(* Not sure this one is really necessary *)
and ctx_normalize_trait_decl_ref (ctx : 'r norm_ctx)
@@ -335,6 +405,10 @@ let mk_rnorm_ctx (ctx : C.eval_ctx) : T.RegionId.id T.region norm_ctx =
get_ty_repr;
convert_ety = TypesUtils.ety_no_regions_to_rty;
convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref;
+ ty_to_string = PA.rty_to_string ctx;
+ trait_ref_to_string = PA.rtrait_ref_to_string ctx;
+ trait_instance_id_to_string = PA.rtrait_instance_id_to_string ctx;
+ pp_r = T.pp_region T.pp_region_id;
}
let mk_enorm_ctx (ctx : C.eval_ctx) : T.erased_region norm_ctx =
@@ -344,6 +418,10 @@ let mk_enorm_ctx (ctx : C.eval_ctx) : T.erased_region norm_ctx =
get_ty_repr;
convert_ety = (fun x -> x);
convert_etrait_ref = (fun x -> x);
+ ty_to_string = PA.ety_to_string ctx;
+ trait_ref_to_string = PA.etrait_ref_to_string ctx;
+ trait_instance_id_to_string = PA.etrait_instance_id_to_string ctx;
+ pp_r = T.pp_erased_region;
}
let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty =
diff --git a/compiler/Config.ml b/compiler/Config.ml
index ccbb4c75..508746d9 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -331,3 +331,11 @@ let record_fields_short_names = ref false
and to account for type constraints (like [fn f<T : Foo>(...) where T::bar = usize]).
*)
let parameterize_trait_types = ref false
+
+(** For sanity check: type check the generated pure code (activates checks in
+ several places).
+
+ TODO: deactivated for now because we need to implement the normalization of
+ trait associated types in the pure code.
+ *)
+let type_check_pure_code = ref false
diff --git a/compiler/Print.ml b/compiler/Print.ml
index 92743bc1..93a1f970 100644
--- a/compiler/Print.ml
+++ b/compiler/Print.ml
@@ -651,6 +651,28 @@ module EvalCtxLlbcAst = struct
let fmt = PC.ctx_to_rtype_formatter fmt in
PT.rty_to_string fmt t
+ let etrait_ref_to_string (ctx : C.eval_ctx) (x : T.etrait_ref) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_etype_formatter fmt in
+ PT.etrait_ref_to_string fmt x
+
+ let rtrait_ref_to_string (ctx : C.eval_ctx) (x : T.rtrait_ref) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_rtype_formatter fmt in
+ PT.rtrait_ref_to_string fmt x
+
+ let etrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.etrait_instance_id)
+ : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_etype_formatter fmt in
+ PT.etrait_instance_id_to_string fmt x
+
+ let rtrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.rtrait_instance_id)
+ : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_rtype_formatter fmt in
+ PT.rtrait_instance_id_to_string fmt x
+
let egeneric_args_to_string (ctx : C.eval_ctx) (x : T.egeneric_args) : string
=
let fmt = PC.eval_ctx_to_ctx_formatter ctx in
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 4c5b99c3..2e0e9862 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -687,8 +687,9 @@ let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit =
()
let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit =
- let ctx = mk_type_check_ctx ctx in
- PureTypeCheck.check_texpression ctx e
+ if !Config.type_check_pure_code then
+ let ctx = mk_type_check_ctx ctx in
+ PureTypeCheck.check_texpression ctx e
let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
(id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref =
@@ -1817,9 +1818,11 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
(* Group the two lists *)
let variables_values = List.combine given_back_variables consumed_values in
(* Sanity check: the two lists match (same types) *)
- List.iter
- (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
- variables_values;
+ (* TODO: normalize the types *)
+ if !Config.type_check_pure_code then
+ List.iter
+ (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
+ variables_values;
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Generate the assignemnts *)
@@ -1892,31 +1895,35 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
(* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *)
- let _ =
- let inst_sg = get_instantiated_fun_sig fun_id (Some rg_id) generics ctx in
- log#ldebug
- (lazy
- ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
- ^ string_of_int (List.length inputs)
- ^ "): "
- ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
- ^ "\n- inst_sg.inputs ("
- ^ string_of_int (List.length inst_sg.inputs)
- ^ "): "
- ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
- List.iter
- (fun (x, ty) -> assert ((x : texpression).ty = ty))
- (List.combine inputs inst_sg.inputs);
- log#ldebug
- (lazy
- ("\n- outputs: "
- ^ string_of_int (List.length outputs)
- ^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.doutputs)));
- List.iter
- (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.doutputs)
- in
+ (if (* TODO: normalize the types *) !Config.type_check_pure_code then
+ match fun_id with
+ | A.FunId fun_id ->
+ let inst_sg =
+ get_instantiated_fun_sig fun_id (Some rg_id) generics ctx
+ in
+ log#ldebug
+ (lazy
+ ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
+ ^ string_of_int (List.length inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
+ ^ "\n- inst_sg.inputs ("
+ ^ string_of_int (List.length inst_sg.inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : texpression).ty = ty))
+ (List.combine inputs inst_sg.inputs);
+ log#ldebug
+ (lazy
+ ("\n- outputs: "
+ ^ string_of_int (List.length outputs)
+ ^ "\n- expected outputs: "
+ ^ string_of_int (List.length inst_sg.doutputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
+ (List.combine outputs inst_sg.doutputs)
+ | _ -> (* TODO: trait methods *) ());
(* Retrieve the function id, and register the function call in the context
* if necessary *)
let ctx, func =
@@ -2961,10 +2968,12 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
^ "\n- signature.inputs: "
^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs)
));
- assert (
- List.for_all
- (fun (var, ty) -> (var : var).ty = ty)
- (List.combine inputs signature.inputs));
+ (* TODO: we need to normalize the types *)
+ if !Config.type_check_pure_code then
+ assert (
+ List.for_all
+ (fun (var, ty) -> (var : var).ty = ty)
+ (List.combine inputs signature.inputs));
Some { inputs; inputs_lvs; body }
in