summaryrefslogtreecommitdiff
path: root/compiler/AssociatedTypes.ml
blob: 4de5382aab44f8837dc8415c0264eb053898f3f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
(** This file implements utilities to handle trait associated types, in
    particular with normalization helpers.

    When normalizing a type, we simplify the references to the trait associated
    types, and choose a representative when there are equalities between types
    enforced by local clauses (i.e., clauses of the shape [where Trait1::T = Trait2::U]).
 *)

open Types
open TypesUtils
open Values
open LlbcAst
open Contexts
open Errors
module Subst = Substitute

(** Logger of this module *)
let log = Logging.associated_types_log

(** Apply a substitution to the trait reference of a trait type reference.
    The name of the associated type is left unchanged. *)
let trait_type_ref_substitute (subst : Subst.subst) (r : trait_type_ref) :
    trait_type_ref =
  { r with trait_ref = Subst.trait_ref_substitute subst r.trait_ref }

(** Types seen as an ordered type, so that we can build maps indexed by
    types *)
module TyOrd = struct
  type t = ty

  let compare = compare_ty
  let show_t = show_ty
  let pp_t = pp_ty
  let to_string = show_ty
end

(** Maps indexed by types *)
module TyMap = Collections.MakeMap (TyOrd)

(** Compute, from a list of trait type constraints, the representative of
    every trait type appearing in those constraints.

    Each constraint [Trait::T = ty] equates the trait type [Trait::T] with
    [ty]. The types which are equal are grouped into equivalence classes with
    a union-find structure. When merging the classes of [Trait::T] and [ty],
    the merged class receives the representative that the class of [ty] had
    before the merge. The result maps every trait type which appears in the
    union-find structure to the representative of its class. *)
let compute_norm_trait_types_from_preds (span : Meta.span option)
    (trait_type_constraints : trait_type_constraint list) : ty TraitTypeRefMap.t
    =
  (* The union-find nodes, indexed by the type they were created for *)
  let nodes : ty UnionFind.elem TyMap.t ref = ref TyMap.empty in
  (* Retrieve the node of a type, creating it if necessary *)
  let node_of (t : ty) : ty UnionFind.elem =
    match TyMap.find_opt t !nodes with
    | Some node -> node
    | None ->
        let node = UnionFind.make t in
        nodes := TyMap.add t node !nodes;
        node
  in
  (* Merge the classes of the two sides of a constraint *)
  let register (c : trait_type_constraint) : unit =
    (* The constraint must not mention regions. Checking the field [ty] alone
       would be enough, but we check the whole constraint to be safe. *)
    sanity_check_opt_span __FILE__ __LINE__
      (trait_type_constraint_no_regions c)
      span;
    let { trait_ref; type_name; ty } : trait_type_constraint = c in
    let lhs = node_of (TTraitType (trait_ref, type_name)) in
    let rhs = node_of ty in
    (* Remember the representative of the right-hand side before merging:
       we want to control which representative the merged class gets *)
    let repr = UnionFind.get rhs in
    UnionFind.set (UnionFind.union lhs rhs) repr
  in
  (* Explore the local predicates *)
  List.iter register trait_type_constraints;
  (* TODO: explore the local clauses *)
  (* Only keep the trait types, mapped to the representative of their class *)
  TyMap.bindings !nodes
  |> List.filter_map (fun (key, node) ->
         match key with
         | TTraitType (trait_ref, type_name) ->
             Some (({ trait_ref; type_name } : trait_type_ref), UnionFind.get node)
         | _ -> None)
  |> TraitTypeRefMap.of_list

(** Compute the representatives of the trait types from the given trait type
    constraints (see {!compute_norm_trait_types_from_preds}) and store them in
    the field [norm_trait_types] of the evaluation context.

    NOTE(review): the previous content of [norm_trait_types] is overwritten,
    not merged with the new representatives. *)
let ctx_add_norm_trait_types_from_preds (span : Meta.span) (ctx : eval_ctx)
    (trait_type_constraints : trait_type_constraint list) : eval_ctx =
  {
    ctx with
    norm_trait_types =
      compute_norm_trait_types_from_preds (Some span) trait_type_constraints;
  }

(** Check whether a trait instance id refers to a local clause, that is
    whether it is built only from the variants [Self], [Clause],
    [ParentClause] and [ItemClause]. *)
let rec trait_instance_id_is_local_clause (id : trait_instance_id) : bool =
  match id with
  | ParentClause (parent, _, _) | ItemClause (parent, _, _, _) ->
      (* A sub-clause is local iff the clause it comes from is local *)
      trait_instance_id_is_local_clause parent
  | Self | Clause _ -> true
  | TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ | FnPointer _
  | Closure _ | Unsolved _ ->
      false

(** Context used by the normalization functions of this module.

    It contains:
    - the representatives of the trait types ([norm_trait_types]), which are
      looked up when normalizing trait types
    - the declarations: [trait_impls] is used to look up trait
      implementations, and all the declarations are used to pretty-print
      debugging messages
    - the type and const generic variables, which (in this file) are only
      used to pretty-print debugging messages

    Remark: the normalization functions should be applied to types without
    regions (TODO: merge ety, rty, etc.).
    NOTE(review): this remark used to mention "conversion functions", which
    this record does not contain - confirm it is still accurate.
 *)
type norm_ctx = {
  span : Meta.span option;
      (** Span used to report the failures of the sanity checks, if any *)
  norm_trait_types : ty TraitTypeRefMap.t;
      (** Representatives of the trait types *)
  type_decls : type_decl TypeDeclId.Map.t;
  fun_decls : fun_decl FunDeclId.Map.t;
  global_decls : global_decl GlobalDeclId.Map.t;
  trait_decls : trait_decl TraitDeclId.Map.t;
  trait_impls : trait_impl TraitImplId.Map.t;
      (** Looked up when projecting types and clauses out of implementations *)
  type_vars : type_var list;
  const_generic_vars : const_generic_var list;
}

(** Build a pretty-printing environment from a normalization context.

    A normalization context has no regions and no locals, and the trait
    clauses are not needed here, so these are left empty. *)
let norm_ctx_to_fmt_env (ctx : norm_ctx) : Print.fmt_env =
  let ({
         type_decls;
         fun_decls;
         global_decls;
         trait_decls;
         trait_impls;
         type_vars;
         const_generic_vars;
         _;
       }
        : norm_ctx) =
    ctx
  in
  let generics : generic_params =
    {
      TypesUtils.empty_generic_params with
      types = type_vars;
      const_generics = const_generic_vars;
      trait_clauses = [];
    }
  in
  {
    type_decls;
    fun_decls;
    global_decls;
    trait_decls;
    trait_impls;
    regions = [];
    generics;
    locals = [];
  }

(** Look up the representative of a trait type, if it has one *)
let norm_ctx_get_ty_repr (ctx : norm_ctx) (x : trait_type_ref) : ty option =
  ctx.norm_trait_types |> TraitTypeRefMap.find_opt x

(** Pretty-print a type in the environment of a normalization context *)
let ty_to_string (ctx : norm_ctx) (ty : ty) : string =
  Print.Types.ty_to_string (norm_ctx_to_fmt_env ctx) ty

(** Pretty-print a trait reference in the environment of a normalization
    context *)
let trait_ref_to_string (ctx : norm_ctx) (x : trait_ref) : string =
  Print.Types.trait_ref_to_string (norm_ctx_to_fmt_env ctx) x

(** Pretty-print a trait instance id in the environment of a normalization
    context *)
let trait_instance_id_to_string (ctx : norm_ctx) (x : trait_instance_id) :
    string =
  Print.Types.trait_instance_id_to_string (norm_ctx_to_fmt_env ctx) x

(** Pretty-print generic arguments in the environment of a normalization
    context *)
let generic_args_to_string (ctx : norm_ctx) (x : generic_args) : string =
  Print.Types.generic_args_to_string (norm_ctx_to_fmt_env ctx) x

(** Pretty-print generic parameters, comma-separated and between angle
    brackets, in the environment of a normalization context *)
let generic_params_to_string (ctx : norm_ctx) (x : generic_params) : string =
  let params, _ =
    Print.Types.generic_params_to_strings (norm_ctx_to_fmt_env ctx) x
  in
  "<" ^ String.concat ", " params ^ ">"

(** Look up a trait implementation. Also compute the substitution which
    instantiates its generic parameters with [generics].

    The [Self] trait instance given to the substitution is an [UnknownTrait]
    placeholder carrying the name of this function. *)
let norm_ctx_lookup_trait_impl (ctx : norm_ctx) (impl_id : TraitImplId.id)
    (generics : generic_args) : trait_impl * Subst.subst =
  let trait_impl = TraitImplId.Map.find impl_id ctx.trait_impls in
  ( trait_impl,
    Subst.make_subst_from_generics trait_impl.generics generics
      (UnknownTrait __FUNCTION__) )

(** Look up the type bound to the associated type [type_name] in the trait
    implementation [impl_id] instantiated with [generics].

    [List.assoc] raises [Not_found] if the implementation does not define
    [type_name]. *)
let norm_ctx_lookup_trait_impl_ty (ctx : norm_ctx) (impl_id : TraitImplId.id)
    (generics : generic_args) (type_name : string) : ty =
  let trait_impl, subst = norm_ctx_lookup_trait_impl ctx impl_id generics in
  let _, impl_ty = List.assoc type_name trait_impl.types in
  Subst.ty_substitute subst impl_ty

(** Look up the [clause_id]-th parent trait reference of the trait
    implementation [impl_id] instantiated with [generics]. *)
let norm_ctx_lookup_trait_impl_parent_clause (ctx : norm_ctx)
    (impl_id : TraitImplId.id) (generics : generic_args)
    (clause_id : TraitClauseId.id) : trait_ref =
  let trait_impl, subst = norm_ctx_lookup_trait_impl ctx impl_id generics in
  let parent = TraitClauseId.nth trait_impl.parent_trait_refs clause_id in
  (* Sanity check: the parent clause must itself refer to an implementation
     (presumably [trait_instance_id_as_trait_impl] fails otherwise) *)
  ignore (TypesUtils.trait_instance_id_as_trait_impl parent.trait_id);
  Subst.trait_ref_substitute subst parent

(** Look up the [clause_id]-th trait reference attached to the associated type
    [item_name] of the trait implementation [impl_id] instantiated with
    [generics]. *)
let norm_ctx_lookup_trait_impl_item_clause (ctx : norm_ctx)
    (impl_id : TraitImplId.id) (generics : generic_args) (item_name : string)
    (clause_id : TraitClauseId.id) : trait_ref =
  let trait_impl, subst = norm_ctx_lookup_trait_impl ctx impl_id generics in
  let item_clauses, _ = List.assoc item_name trait_impl.types in
  let clause = TraitClauseId.nth item_clauses clause_id in
  (* Sanity check: the clause must refer to an implementation (presumably
     [trait_instance_id_as_trait_impl] fails otherwise) *)
  ignore (TypesUtils.trait_instance_id_as_trait_impl clause.trait_id);
  Subst.trait_ref_substitute subst clause

(** Normalize a type by simplifying the references to trait associated types
    and choosing a representative when there are equalities between types
    enforced by local clauses (i.e., [where Trait1::T = Trait2::U]).

    The type is traversed structurally. The interesting case is the one of
    the trait types ([TTraitType]):
    - we first normalize the trait reference
    - if it designates a trait implementation, we project the associated type
      out of the implementation and normalize the result
    - finally, if [ctx.norm_trait_types] gives a representative for the
      (normalized) trait type, we return this representative instead.

    See the comments for {!norm_ctx_normalize_trait_instance_id}.
  *)
let rec norm_ctx_normalize_ty (ctx : norm_ctx) (ty : ty) : ty =
  log#ldebug (lazy ("norm_ctx_normalize_ty: " ^ ty_to_string ctx ty));
  match ty with
  | TAdt (id, generics) ->
      TAdt (id, norm_ctx_normalize_generic_args ctx generics)
  | TVar _ | TLiteral _ | TNever -> ty
  | TRef (r, ty, rkind) ->
      let ty = norm_ctx_normalize_ty ctx ty in
      TRef (r, ty, rkind)
  | TRawPtr (ty, rkind) ->
      let ty = norm_ctx_normalize_ty ctx ty in
      TRawPtr (ty, rkind)
  | TArrow (regions, inputs, output) ->
      (* TODO: for now it works because we don't support predicates with
         bound regions. If we do support them, we probably need to do
         something smarter here. *)
      let inputs = List.map (norm_ctx_normalize_ty ctx) inputs in
      let output = norm_ctx_normalize_ty ctx output in
      TArrow (regions, inputs, output)
  | TTraitType (trait_ref, type_name) -> (
      log#ldebug
        (lazy
          ("norm_ctx_normalize_ty:\n- trait type: " ^ ty_to_string ctx ty
         ^ "\n- trait_ref: "
          ^ trait_ref_to_string ctx trait_ref
          ^ "\n- raw trait ref:\n" ^ show_trait_ref trait_ref));
      (* Normalize and attempt to project the type from the trait ref *)
      let trait_ref = norm_ctx_normalize_trait_ref ctx trait_ref in
      let ty : ty =
        match trait_ref.trait_id with
        (* NOTE(review): when [norm_ctx_normalize_trait_ref] finds an
           implementation, it seems to return the reference to that
           implementation (with a bare [TraitImpl] id) rather than a
           [TraitRef] wrapper, in which case this branch is unreachable and
           the [TraitImpl] branch below is taken instead - confirm. *)
        | TraitRef { trait_id = TraitImpl impl_id; generics = ref_generics; _ }
          ->
            cassert_opt_span __FILE__ __LINE__
              (ref_generics = empty_generic_args)
              ctx.span "Higher order trait types are not supported yet";
            log#ldebug
              (lazy
                ("norm_ctx_normalize_ty: trait type: trait ref: "
               ^ ty_to_string ctx ty));
            (* Lookup the type *)
            let ty =
              norm_ctx_lookup_trait_impl_ty ctx impl_id trait_ref.generics
                type_name
            in
            (* Normalize *)
            norm_ctx_normalize_ty ctx ty
        | TraitImpl impl_id ->
            log#ldebug
              (lazy
                ("norm_ctx_normalize_ty (trait impl):\n- trait type: "
               ^ ty_to_string ctx ty ^ "\n- trait_ref: "
                ^ trait_ref_to_string ctx trait_ref
                ^ "\n- raw trait ref:\n" ^ show_trait_ref trait_ref));
            (* This happens. This doesn't come from the substitutions
               performed by Aeneas (the [TraitImpl] would be wrapped in a
               [TraitRef]) but from non-normalized traits translated from
               the Rustc AST.
               NOTE(review): [norm_ctx_normalize_trait_ref] also seems to
               produce a bare [TraitImpl] when it unwraps a
               [TraitRef (TraitImpl ...)], so this branch is likely taken in
               that case too - confirm.
               TODO: factor out with the branch above.
            *)
            (* Lookup the type *)
            let ty =
              norm_ctx_lookup_trait_impl_ty ctx impl_id trait_ref.generics
                type_name
            in
            (* Normalize *)
            norm_ctx_normalize_ty ctx ty
        | _ ->
            log#ldebug
              (lazy
                ("norm_ctx_normalize_ty: trait type: not a trait ref: "
               ^ ty_to_string ctx ty ^ "\n- trait_ref: "
                ^ trait_ref_to_string ctx trait_ref
                ^ "\n- raw trait ref:\n" ^ show_trait_ref trait_ref));
            (* We can't project: the trait reference must then be a local
               clause, and we keep the (normalized) trait type as is *)
            sanity_check_opt_span __FILE__ __LINE__
              (trait_instance_id_is_local_clause trait_ref.trait_id)
              ctx.span;
            TTraitType (trait_ref, type_name)
      in
      let tr : trait_type_ref = { trait_ref; type_name } in
      (* Lookup the representative, if there is one: it takes precedence over
         the type computed above *)
      match norm_ctx_get_ty_repr ctx tr with None -> ty | Some ty -> ty)

(** This returns the normalized trait instance id together with an optional
    reference to a trait **implementation** (when present, the [trait_ref] we
    return necessarily has a [TraitImpl] instance id).

    We need this in particular to simplify the trait instance ids after we
    performed a substitution.

    Example:
    ========
    {[
      trait Trait {
        type S
      }

      impl TraitImpl for Foo {
        type S = usize
      }

      fn f<T : Trait>(...) -> T::S;

      ...
      let x = f<Foo>[TraitImpl](...);
      (* The return type of the call to f is:
         T::S ~~> TraitImpl::S ~~> usize
       *)
    ]}

    Several remarks:
    - as we do not allow higher-order types (yet), local clauses (and
      sub-clauses) can't have generic arguments
    - the [TraitRef] case only happens because of substitution, the role of
      the normalization is in particular to eliminate it. Inside a [TraitRef]
      there is necessarily:
      - an id referencing a local (sub-)clause, that is an id using the variants
        [Self], [Clause], [ItemClause] and [ParentClause] exclusively. We can't
        simplify those cases: all we can do is remove the [TraitRef] wrapper
        by leveraging the fact that the generic arguments must be empty.
      - a [TraitImpl]. Note that the [TraitImpl] is necessarily just a [TraitImpl],
        it can't be for instance a [ParentClause(TraitImpl ...)] because the
        trait resolution would then directly reference the implementation
        designated by [ParentClause(TraitImpl ...)] (and same for the other cases).
        In this case we can lookup the trait implementation and recursively project
        over it.
 *)
and norm_ctx_normalize_trait_instance_id (ctx : norm_ctx)
    (id : trait_instance_id) : trait_instance_id * trait_ref option =
  match id with
  | Self -> (id, None)
  | TraitImpl _ ->
      (* The [TraitImpl] shouldn't be inside any projection - we check this
         elsewhere by asserting that whenever we return [None] for the impl
         trait ref, then the id actually refers to a local clause. *)
      (id, None)
  | Clause _ -> (id, None)
  | BuiltinOrAuto _ -> (id, None)
  | ParentClause (inst_id, decl_id, clause_id) -> (
      let inst_id, impl = norm_ctx_normalize_trait_instance_id ctx inst_id in
      (* Check if the inst_id refers to a specific implementation, if yes project *)
      match impl with
      | None ->
          (* This is actually a local clause *)
          sanity_check_opt_span __FILE__ __LINE__
            (trait_instance_id_is_local_clause inst_id)
            ctx.span;
          (ParentClause (inst_id, decl_id, clause_id), None)
      | Some impl ->
          (* We figure out the parent clause by doing the following:
             {[
               // The implementation we are looking at
               impl Impl1 : Trait1 { ... }

               // Check the trait it implements
               trait Trait1 : ParentTrait1 + ParentTrait2 { ... }
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
                              those are the parent clauses
             ]}
          *)
          (* Lookup the clause *)
          let impl_id =
            TypesUtils.trait_instance_id_as_trait_impl impl.trait_id
          in
          let clause =
            norm_ctx_lookup_trait_impl_parent_clause ctx impl_id impl.generics
              clause_id
          in
          (* Normalize the clause: it may itself need to be simplified *)
          let clause = norm_ctx_normalize_trait_ref ctx clause in
          (TraitRef clause, Some clause))
  | ItemClause (inst_id, decl_id, item_name, clause_id) -> (
      let inst_id, impl = norm_ctx_normalize_trait_instance_id ctx inst_id in
      (* Check if the inst_id refers to a specific implementation, if yes project *)
      match impl with
      | None ->
          (* This is actually a local clause *)
          sanity_check_opt_span __FILE__ __LINE__
            (trait_instance_id_is_local_clause inst_id)
            ctx.span;
          (ItemClause (inst_id, decl_id, item_name, clause_id), None)
      | Some impl ->
          (* We figure out the item clause by doing the following:
             {[
               // The implementation we are looking at
               impl Impl1 : Trait1<R> {
                  type S = ...
                     with Impl2 : Trait2 ... // Instances satisfying the declared bounds
                          ^^^^^^^^^^^^^^^^^^
                      Lookup the clause from here
               }
             ]}
          *)
          (* Lookup the impl *)
          let impl_id =
            TypesUtils.trait_instance_id_as_trait_impl impl.trait_id
          in
          let clause =
            norm_ctx_lookup_trait_impl_item_clause ctx impl_id impl.generics
              item_name clause_id
          in
          (* Normalize the clause: it may itself need to be simplified *)
          let clause = norm_ctx_normalize_trait_ref ctx clause in
          (TraitRef clause, Some clause))
  | TraitRef { trait_id = TraitImpl trait_id; generics; trait_decl_ref } ->
      (* We can't simplify the id *yet* : we will simplify it when projecting.
         However, we have an implementation to return *)
      (* Normalize the generics *)
      let generics = norm_ctx_normalize_generic_args ctx generics in
      let trait_decl_ref =
        norm_ctx_normalize_trait_decl_ref ctx trait_decl_ref
      in
      let trait_ref : trait_ref =
        { trait_id = TraitImpl trait_id; generics; trait_decl_ref }
      in
      (TraitRef trait_ref, Some trait_ref)
  | TraitRef trait_ref ->
      (* The trait instance id necessarily refers to a local sub-clause. We
         can't project over it and can only peel off the [TraitRef] wrapper *)
      sanity_check_opt_span __FILE__ __LINE__
        (trait_instance_id_is_local_clause trait_ref.trait_id)
        ctx.span;
      sanity_check_opt_span __FILE__ __LINE__
        (trait_ref.generics = empty_generic_args)
        ctx.span;
      (trait_ref.trait_id, None)
  | FnPointer ty ->
      let ty = norm_ctx_normalize_ty ctx ty in
      (* TODO: we might want to return the ref to the function pointer,
         in order to later normalize a call to this function pointer *)
      (FnPointer ty, None)
  | Closure (fid, generics) ->
      let generics = norm_ctx_normalize_generic_args ctx generics in
      (Closure (fid, generics), None)
  | Unsolved _ | UnknownTrait _ ->
      (* This is actually an error case: we leave the id unchanged *)
      (id, None)

(** Normalize the types and the trait references of generic arguments. The
    regions and the const generics are left unchanged. *)
and norm_ctx_normalize_generic_args (ctx : norm_ctx) (generics : generic_args) :
    generic_args =
  (* The types are normalized before the trait references *)
  let norm_types = List.map (norm_ctx_normalize_ty ctx) generics.types in
  let norm_trait_refs =
    List.map (norm_ctx_normalize_trait_ref ctx) generics.trait_refs
  in
  { generics with types = norm_types; trait_refs = norm_trait_refs }

(** Normalize a trait reference.

    If its instance id normalizes to a reference to a trait implementation,
    this implementation reference is returned directly (the generics of the
    input reference must then be empty). Otherwise, the normalized instance id
    is kept and the generic arguments and the trait declaration reference are
    normalized. *)
and norm_ctx_normalize_trait_ref (ctx : norm_ctx) (trait_ref : trait_ref) :
    trait_ref =
  log#ldebug
    (lazy
      ("norm_ctx_normalize_trait_ref: "
      ^ trait_ref_to_string ctx trait_ref
      ^ "\n- raw trait ref:\n" ^ show_trait_ref trait_ref));
  let norm_id, impl_ref =
    norm_ctx_normalize_trait_instance_id ctx trait_ref.trait_id
  in
  match impl_ref with
  | Some impl_ref ->
      log#ldebug
        (lazy
          ("norm_ctx_normalize_trait_ref: normalized to: "
          ^ trait_ref_to_string ctx impl_ref));
      sanity_check_opt_span __FILE__ __LINE__
        (trait_ref.generics = empty_generic_args)
        ctx.span;
      impl_ref
  | None ->
      log#ldebug
        (lazy
          ("norm_ctx_normalize_trait_ref: no norm: "
          ^ trait_instance_id_to_string ctx norm_id));
      let norm_generics =
        norm_ctx_normalize_generic_args ctx trait_ref.generics
      in
      let norm_decl_ref =
        norm_ctx_normalize_trait_decl_ref ctx trait_ref.trait_decl_ref
      in
      {
        trait_id = norm_id;
        generics = norm_generics;
        trait_decl_ref = norm_decl_ref;
      }

(* Normalize the generics of a trait declaration reference.
   Not sure this one is really necessary *)
and norm_ctx_normalize_trait_decl_ref (ctx : norm_ctx)
    (trait_decl_ref : trait_decl_ref) : trait_decl_ref =
  let decl_generics =
    norm_ctx_normalize_generic_args ctx trait_decl_ref.decl_generics
  in
  { trait_decl_ref with decl_generics }

(** Normalize both sides of a trait type constraint: the trait reference
    (normalized first) and the type. The type name is left unchanged. *)
let norm_ctx_normalize_trait_type_constraint (ctx : norm_ctx)
    (ttc : trait_type_constraint) : trait_type_constraint =
  let trait_ref = norm_ctx_normalize_trait_ref ctx ttc.trait_ref in
  let ty = norm_ctx_normalize_ty ctx ttc.ty in
  { ttc with trait_ref; ty }

(** Build a normalization context from an evaluation context: we copy the
    representatives of the trait types, the declarations and the generic
    variables in scope. *)
let mk_norm_ctx (span : Meta.span) (ctx : eval_ctx) : norm_ctx =
  let ({ type_ctx; fun_ctx; global_ctx; trait_decls_ctx; trait_impls_ctx; _ }
        : eval_ctx) =
    ctx
  in
  {
    span = Some span;
    norm_trait_types = ctx.norm_trait_types;
    type_decls = type_ctx.type_decls;
    fun_decls = fun_ctx.fun_decls;
    global_decls = global_ctx.global_decls;
    trait_decls = trait_decls_ctx.trait_decls;
    trait_impls = trait_impls_ctx.trait_impls;
    type_vars = ctx.type_vars;
    const_generic_vars = ctx.const_generic_vars;
  }

(** Normalize a type, using the information of an evaluation context *)
let ctx_normalize_ty (span : Meta.span) (ctx : eval_ctx) (ty : ty) : ty =
  let nctx = mk_norm_ctx span ctx in
  norm_ctx_normalize_ty nctx ty

(** Normalize a type and erase the regions at the same time *)
let ctx_normalize_erase_ty (span : Meta.span) (ctx : eval_ctx) (ty : ty) : ty =
  ctx_normalize_ty span ctx ty |> Subst.erase_regions

(** Normalize a trait type constraint, using the information of an evaluation
    context *)
let ctx_normalize_trait_type_constraint (span : Meta.span) (ctx : eval_ctx)
    (ttc : trait_type_constraint) : trait_type_constraint =
  let nctx = mk_norm_ctx span ctx in
  norm_ctx_normalize_trait_type_constraint nctx ttc

(** Same as [Subst.type_decl_get_instantiated_variants_fields_types] but
    normalizes the types of the fields of every variant *)
let type_decl_get_inst_norm_variants_fields_rtypes (span : Meta.span)
    (ctx : eval_ctx) (def : type_decl) (generics : generic_args) :
    (VariantId.id option * ty list) list =
  let normalize = ctx_normalize_ty span ctx in
  Subst.type_decl_get_instantiated_variants_fields_types def generics
  |> List.map (fun (variant_id, field_tys) ->
         (variant_id, List.map normalize field_tys))

(** Same as [Subst.type_decl_get_instantiated_field_types] but normalizes the
    types of the fields *)
let type_decl_get_inst_norm_field_rtypes (span : Meta.span) (ctx : eval_ctx)
    (def : type_decl) (opt_variant_id : VariantId.id option)
    (generics : generic_args) : ty list =
  Subst.type_decl_get_instantiated_field_types def opt_variant_id generics
  |> List.map (ctx_normalize_ty span ctx)

(** Same as [Subst.ctx_adt_value_get_instantiated_field_types] but normalizes
    the types of the fields.

    (The previous documentation referred to a non-existent
    [ctx_adt_value_get_instantiated_field_rtypes].) *)
let ctx_adt_value_get_inst_norm_field_rtypes (span : Meta.span) (ctx : eval_ctx)
    (adt : adt_value) (id : type_id) (generics : generic_args) : ty list =
  Subst.ctx_adt_value_get_instantiated_field_types span ctx adt id generics
  |> List.map (ctx_normalize_ty span ctx)

(** Same as [Subst.type_decl_get_instantiated_field_types] but normalizes the
    types of the fields and erases their regions.

    (The previous documentation wrongly referred to
    [ctx_adt_value_get_instantiated_field_types], which is not the function
    called here.) *)
let type_decl_get_inst_norm_field_etypes (span : Meta.span) (ctx : eval_ctx)
    (def : type_decl) (opt_variant_id : VariantId.id option)
    (generics : generic_args) : ty list =
  let types =
    Subst.type_decl_get_instantiated_field_types def opt_variant_id generics
  in
  (* Reuse [ctx_normalize_erase_ty] so that all the "normalize then erase the
     regions" helpers share the same implementation *)
  List.map (ctx_normalize_erase_ty span ctx) types

(** Same as [Subst.ctx_adt_get_instantiated_field_types] but normalizes the
    types of the fields and erases their regions. *)
let ctx_adt_get_inst_norm_field_etypes (span : Meta.span) (ctx : eval_ctx)
    (def_id : TypeDeclId.id) (opt_variant_id : VariantId.id option)
    (generics : generic_args) : ty list =
  let field_tys =
    Subst.ctx_adt_get_instantiated_field_types ctx def_id opt_variant_id
      generics
  in
  List.map (ctx_normalize_erase_ty span ctx) field_tys

(** Same as [Subst.substitute_signature] but also normalizes the input types,
    the output type and the trait type constraints of the instantiated
    signature. *)
let ctx_subst_norm_signature (span : Meta.span) (ctx : eval_ctx)
    (asubst : RegionGroupId.id -> AbstractionId.id)
    (r_subst : RegionVarId.id -> RegionId.id) (ty_subst : TypeVarId.id -> ty)
    (cg_subst : ConstGenericVarId.id -> const_generic)
    (tr_subst : TraitClauseId.id -> trait_instance_id)
    (tr_self : trait_instance_id) (sg : fun_sig)
    (regions_hierarchy : region_var_groups) : inst_fun_sig =
  let inst_sg : inst_fun_sig =
    Subst.substitute_signature asubst r_subst ty_subst cg_subst tr_subst tr_self
      sg regions_hierarchy
  in
  (* Normalize in order: the inputs, the output, then the constraints *)
  let normalize = ctx_normalize_ty span ctx in
  let inputs = List.map normalize inst_sg.inputs in
  let output = normalize inst_sg.output in
  let trait_type_constraints =
    List.map
      (ctx_normalize_trait_type_constraint span ctx)
      inst_sg.trait_type_constraints
  in
  { inst_sg with inputs; output; trait_type_constraints }