From d556b2439ad858fbbf612f433d25363a8f4a7c83 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 13 Sep 2023 18:43:23 +0200 Subject: Fix more issues --- compiler/AssociatedTypes.ml | 90 ++++++++++++++++++++++++++++++++++++++++++--- compiler/Config.ml | 8 ++++ compiler/Print.ml | 22 +++++++++++ compiler/SymbolicToPure.ml | 77 +++++++++++++++++++++----------------- 4 files changed, 157 insertions(+), 40 deletions(-) diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml index c4a9538d..bce0fb11 100644 --- a/compiler/AssociatedTypes.ml +++ b/compiler/AssociatedTypes.ml @@ -15,6 +15,7 @@ module C = Contexts module Subst = Substitute module L = Logging module UF = UnionFind +module PA = Print.EvalCtxLlbcAst (** The local logger *) let log = L.associated_types_log @@ -111,6 +112,10 @@ type 'r norm_ctx = { get_ty_repr : 'r C.trait_type_ref -> 'r T.ty option; convert_ety : T.ety -> 'r T.ty; convert_etrait_ref : T.etrait_ref -> 'r T.trait_ref; + ty_to_string : 'r T.ty -> string; + trait_ref_to_string : 'r T.trait_ref -> string; + trait_instance_id_to_string : 'r T.trait_instance_id -> string; + pp_r : Format.formatter -> 'r -> unit; } (** Normalize a type by simplyfying the references to trait associated types @@ -118,6 +123,7 @@ type 'r norm_ctx = { enforced by local clauses (i.e., `where Trait1::T = Trait2::U`. *) let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty = fun ctx ty -> + log#ldebug (lazy ("ctx_normalize_ty: " ^ ctx.ty_to_string ty)); match ty with | T.Adt (id, generics) -> Adt (id, ctx_normalize_generic_args ctx generics) | TypeVar _ | Literal _ | Never -> ty @@ -125,19 +131,56 @@ let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty = let ty = ctx_normalize_ty ctx ty in T.Ref (r, ty, rkind) | TraitType (trait_ref, generics, type_name) -> ( + log#ldebug + (lazy + ("ctx_normalize_ty: trait type: " ^ ctx.ty_to_string ty + ^ "\n- trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref: " + ^ T.show_trait_ref ctx.pp_r trait_ref)); (* Normalize and attempt to project the type from the trait ref *) let trait_ref = ctx_normalize_trait_ref ctx trait_ref in let generics = ctx_normalize_generic_args ctx generics in let ty : 'r T.ty = match trait_ref.trait_id with - | T.TraitRef { T.trait_id = T.TraitImpl impl_id; generics; _ } -> + | T.TraitRef + { T.trait_id = T.TraitImpl impl_id; generics = ref_generics; _ } -> + assert (ref_generics = TypesUtils.mk_empty_generic_args); + log#ldebug + (lazy + ("ctx_normalize_ty: trait type: trait ref: " + ^ ctx.ty_to_string ty)); (* Lookup the implementation *) let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in (* Lookup the type *) let ty = snd (List.assoc type_name trait_impl.types) in (* Annoying: convert etype to an stype - TODO: hwo to avoid that? *) let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in - (* Substitute - annoying: we can't use *) + (* Substitute *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = + Subst.make_subst_from_generics_no_regions trait_impl.generics + generics tr_self + in + let ty = Subst.ty_substitute subst ty in + (* Reconvert *) + let ty : 'r T.ty = ctx.convert_ety (Subst.erase_regions ty) in + (* Normalize *) + ctx_normalize_ty ctx ty + | T.TraitImpl impl_id -> + (* This happens. This doesn't come from the substituations + performed by Aeneas (the [TraitImpl] would be wrapped in a + [TraitRef] but from non-normalized traits translated from + the Rustc AST. + TODO: factor out with the branch above. + *) + (* Lookup the implementation *) + let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in + (* Lookup the type *) + let ty = snd (List.assoc type_name trait_impl.types) in + (* Annoying: convert etype to an stype - TODO: hwo to avoid that? *) + let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in + (* Substitute *) let tr_self = T.UnknownTrait __FUNCTION__ in let subst = Subst.make_subst_from_generics_no_regions trait_impl.generics @@ -149,6 +192,13 @@ let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty = (* Normalize *) ctx_normalize_ty ctx ty | _ -> + log#ldebug + (lazy + ("ctx_normalize_ty: trait type: not a trait ref: " + ^ ctx.ty_to_string ty ^ "\n- trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref: " + ^ T.show_trait_ref ctx.pp_r trait_ref)); (* We can't project *) assert (trait_instance_id_is_local_clause trait_ref.trait_id); T.TraitType (trait_ref, generics, type_name) @@ -307,11 +357,31 @@ and ctx_normalize_generic_args (ctx : 'r norm_ctx) and ctx_normalize_trait_ref (ctx : 'r norm_ctx) (trait_ref : 'r T.trait_ref) : 'r T.trait_ref = + log#ldebug + (lazy + ("ctx_normalize_trait_ref: " + ^ ctx.trait_ref_to_string trait_ref + ^ "\n- raw trait ref:\n" + ^ T.show_trait_ref ctx.pp_r trait_ref)); let { T.trait_id; generics; trait_decl_ref } = trait_ref in - let trait_id, _ = ctx_normalize_trait_instance_id ctx trait_id in - let generics = ctx_normalize_generic_args ctx generics in - let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in - { T.trait_id; generics; trait_decl_ref } + (* Check if the id is an impl, otherwise normalize it *) + let trait_id, norm_trait_ref = ctx_normalize_trait_instance_id ctx trait_id in + match norm_trait_ref with + | None -> + log#ldebug + (lazy + ("ctx_normalize_trait_ref: no norm: " + ^ ctx.trait_instance_id_to_string trait_id)); + let generics = ctx_normalize_generic_args ctx generics in + let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in + { T.trait_id; generics; trait_decl_ref } + | Some trait_ref -> + log#ldebug + (lazy + ("ctx_normalize_trait_ref: normalized to: " + ^ ctx.trait_ref_to_string trait_ref)); + assert (generics = TypesUtils.mk_empty_generic_args); + trait_ref (* Not sure this one is really necessary *) and ctx_normalize_trait_decl_ref (ctx : 'r norm_ctx) @@ -335,6 +405,10 @@ let mk_rnorm_ctx (ctx : C.eval_ctx) : T.RegionId.id T.region norm_ctx = get_ty_repr; convert_ety = TypesUtils.ety_no_regions_to_rty; convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref; + ty_to_string = PA.rty_to_string ctx; + trait_ref_to_string = PA.rtrait_ref_to_string ctx; + trait_instance_id_to_string = PA.rtrait_instance_id_to_string ctx; + pp_r = T.pp_region T.pp_region_id; } let mk_enorm_ctx (ctx : C.eval_ctx) : T.erased_region norm_ctx = @@ -344,6 +418,10 @@ let mk_enorm_ctx (ctx : C.eval_ctx) : T.erased_region norm_ctx = get_ty_repr; convert_ety = (fun x -> x); convert_etrait_ref = (fun x -> x); + ty_to_string = PA.ety_to_string ctx; + trait_ref_to_string = PA.etrait_ref_to_string ctx; + trait_instance_id_to_string = PA.etrait_instance_id_to_string ctx; + pp_r = T.pp_erased_region; } let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty = diff --git a/compiler/Config.ml b/compiler/Config.ml index ccbb4c75..508746d9 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -331,3 +331,11 @@ let record_fields_short_names = ref false and to account for type constraints (like [fn f(...) where T::bar = usize]). *) let parameterize_trait_types = ref false + +(** For sanity check: type check the generated pure code (activates checks in + several places). + + TODO: deactivated for now because we need to implement the normalization of + trait associated types in the pure code. + *) +let type_check_pure_code = ref false diff --git a/compiler/Print.ml b/compiler/Print.ml index 92743bc1..93a1f970 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -651,6 +651,28 @@ module EvalCtxLlbcAst = struct let fmt = PC.ctx_to_rtype_formatter fmt in PT.rty_to_string fmt t + let etrait_ref_to_string (ctx : C.eval_ctx) (x : T.etrait_ref) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_etype_formatter fmt in + PT.etrait_ref_to_string fmt x + + let rtrait_ref_to_string (ctx : C.eval_ctx) (x : T.rtrait_ref) : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_rtype_formatter fmt in + PT.rtrait_ref_to_string fmt x + + let etrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.etrait_instance_id) + : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_etype_formatter fmt in + PT.etrait_instance_id_to_string fmt x + + let rtrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.rtrait_instance_id) + : string = + let fmt = PC.eval_ctx_to_ctx_formatter ctx in + let fmt = PC.ctx_to_rtype_formatter fmt in + PT.rtrait_instance_id_to_string fmt x + let egeneric_args_to_string (ctx : C.eval_ctx) (x : T.egeneric_args) : string = let fmt = PC.eval_ctx_to_ctx_formatter ctx in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 4c5b99c3..2e0e9862 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -687,8 +687,9 @@ let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit = () let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit = - let ctx = mk_type_check_ctx ctx in - PureTypeCheck.check_texpression ctx e + if !Config.type_check_pure_code then + let ctx = mk_type_check_ctx ctx in + PureTypeCheck.check_texpression ctx e let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) (id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref = @@ -1817,9 +1818,11 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) (* Group the two lists *) let variables_values = List.combine given_back_variables consumed_values in (* Sanity check: the two lists match (same types) *) - List.iter - (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty)) - variables_values; + (* TODO: normalize the types *) + if !Config.type_check_pure_code then + List.iter + (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty)) + variables_values; (* Translate the next expression *) let next_e = translate_expression e ctx in (* Generate the assignemnts *) @@ -1892,31 +1895,35 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) - let _ = - let inst_sg = get_instantiated_fun_sig fun_id (Some rg_id) generics ctx in - log#ldebug - (lazy - ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" - ^ string_of_int (List.length inputs) - ^ "): " - ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) - ^ "\n- inst_sg.inputs (" - ^ string_of_int (List.length inst_sg.inputs) - ^ "): " - ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs))); - List.iter - (fun (x, ty) -> assert ((x : texpression).ty = ty)) - (List.combine inputs inst_sg.inputs); - log#ldebug - (lazy - ("\n- outputs: " - ^ string_of_int (List.length outputs) - ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.doutputs))); - List.iter - (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.doutputs) - in + (if (* TODO: normalize the types *) !Config.type_check_pure_code then + match fun_id with + | A.FunId fun_id -> + let inst_sg = + get_instantiated_fun_sig fun_id (Some rg_id) generics ctx + in + log#ldebug + (lazy + ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" + ^ string_of_int (List.length inputs) + ^ "): " + ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) + ^ "\n- inst_sg.inputs (" + ^ string_of_int (List.length inst_sg.inputs) + ^ "): " + ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs))); + List.iter + (fun (x, ty) -> assert ((x : texpression).ty = ty)) + (List.combine inputs inst_sg.inputs); + log#ldebug + (lazy + ("\n- outputs: " + ^ string_of_int (List.length outputs) + ^ "\n- expected outputs: " + ^ string_of_int (List.length inst_sg.doutputs))); + List.iter + (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) + (List.combine outputs inst_sg.doutputs) + | _ -> (* TODO: trait methods *) ()); (* Retrieve the function id, and register the function call in the context * if necessary *) let ctx, func = @@ -2961,10 +2968,12 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = ^ "\n- signature.inputs: " ^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs) )); - assert ( - List.for_all - (fun (var, ty) -> (var : var).ty = ty) - (List.combine inputs signature.inputs)); + (* TODO: we need to normalize the types *) + if !Config.type_check_pure_code then + assert ( + List.for_all + (fun (var, ty) -> (var : var).ty = ty) + (List.combine inputs signature.inputs)); Some { inputs; inputs_lvs; body } in -- cgit v1.2.3