diff options
Diffstat (limited to 'compiler/Substitute.ml')
-rw-r--r-- | compiler/Substitute.ml | 357 |
1 files changed, 357 insertions, 0 deletions
diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml new file mode 100644 index 00000000..5e5858de --- /dev/null +++ b/compiler/Substitute.ml @@ -0,0 +1,357 @@ +(** This file implements various substitution utilities to instantiate types, + function bodies, etc. + *) + +module T = Types +module TU = TypesUtils +module V = Values +module E = Expressions +module A = LlbcAst +module C = Contexts + +(** Substitute types variables and regions in a type. + + TODO: we can reimplement that with visitors. + *) +let rec ty_substitute (rsubst : 'r1 -> 'r2) + (tsubst : T.TypeVarId.id -> 'r2 T.ty) (ty : 'r1 T.ty) : 'r2 T.ty = + let open T in + let subst = ty_substitute rsubst tsubst in + (* helper *) + match ty with + | Adt (def_id, regions, tys) -> + Adt (def_id, List.map rsubst regions, List.map subst tys) + | Array aty -> Array (subst aty) + | Slice sty -> Slice (subst sty) + | Ref (r, ref_ty, ref_kind) -> Ref (rsubst r, subst ref_ty, ref_kind) + (* Below variants: we technically return the same value, but because + one has type ['r1 ty] and the other has type ['r2 ty], we need to + deconstruct then reconstruct *) + | Bool -> Bool + | Char -> Char + | Never -> Never + | Integer int_ty -> Integer int_ty + | Str -> Str + | TypeVar vid -> tsubst vid + +(** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *) +let erase_regions (ty : T.rty) : T.ety = + ty_substitute (fun _ -> T.Erased) (fun vid -> T.TypeVar vid) ty + +(** Generate fresh regions for region variables. + + Return the list of new regions and appropriate substitutions from the + original region variables to the fresh regions. + + TODO: simplify? we only need the subst [T.RegionVarId.id -> T.RegionId.id] + *) +let fresh_regions_with_substs (region_vars : T.region_var list) : + T.RegionId.id list + * (T.RegionVarId.id -> T.RegionId.id) + * (T.RegionVarId.id T.region -> T.RegionId.id T.region) = + (* Generate fresh regions *) + let fresh_region_ids = List.map (fun _ -> C.fresh_region_id ()) region_vars in + (* Generate the map from region var ids to regions *) + let ls = List.combine region_vars fresh_region_ids in + let rid_map = + List.fold_left + (fun mp (k, v) -> T.RegionVarId.Map.add k.T.index v mp) + T.RegionVarId.Map.empty ls + in + (* Generate the substitution from region var id to region *) + let rid_subst id = T.RegionVarId.Map.find id rid_map in + (* Generate the substitution from region to region *) + let rsubst r = + match r with T.Static -> T.Static | T.Var id -> T.Var (rid_subst id) + in + (* Return *) + (fresh_region_ids, rid_subst, rsubst) + +(** Erase the regions in a type and substitute the type variables *) +let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety) + (ty : 'r T.region T.ty) : T.ety = + let rsubst (_ : 'r T.region) : T.erased_region = T.Erased in + ty_substitute rsubst tsubst ty + +(** Create a region substitution from a list of region variable ids and a list of + regions (with which to substitute the region variable ids *) +let make_region_subst (var_ids : T.RegionVarId.id list) + (regions : 'r T.region list) : T.RegionVarId.id T.region -> 'r T.region = + let ls = List.combine var_ids regions in + let mp = + List.fold_left + (fun mp (k, v) -> T.RegionVarId.Map.add k v mp) + T.RegionVarId.Map.empty ls + in + fun r -> + match r with + | T.Static -> T.Static + | T.Var id -> T.RegionVarId.Map.find id mp + +(** Create a type substitution from a list of type variable ids and a list of + types (with which to substitute the type variable ids) *) +let make_type_subst (var_ids : T.TypeVarId.id list) (tys : 'r T.ty list) : + T.TypeVarId.id -> 'r T.ty = + let ls = List.combine var_ids tys in + let mp = + List.fold_left + (fun mp (k, v) -> T.TypeVarId.Map.add k v mp) + T.TypeVarId.Map.empty ls + in + fun id -> T.TypeVarId.Map.find id mp + +(** Instantiate the type variables in an ADT definition, and return, for + every variant, the list of the types of its fields *) +let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) + (regions : T.RegionId.id T.region list) (types : T.rty list) : + (T.VariantId.id option * T.rty list) list = + let r_subst = + make_region_subst + (List.map (fun x -> x.T.index) def.T.region_params) + regions + in + let ty_subst = + make_type_subst (List.map (fun x -> x.T.index) def.T.type_params) types + in + let (variants_fields : (T.VariantId.id option * T.field list) list) = + match def.T.kind with + | T.Enum variants -> + List.mapi + (fun i v -> (Some (T.VariantId.of_int i), v.T.fields)) + variants + | T.Struct fields -> [ (None, fields) ] + | T.Opaque -> + raise + (Failure + ("Can't retrieve the variants of an opaque type: " + ^ Names.name_to_string def.name)) + in + List.map + (fun (id, fields) -> + ( id, + List.map (fun f -> ty_substitute r_subst ty_subst f.T.field_ty) fields + )) + variants_fields + +(** Instantiate the type variables in an ADT definition, and return the list + of types of the fields for the chosen variant *) +let type_decl_get_instantiated_field_rtypes (def : T.type_decl) + (opt_variant_id : T.VariantId.id option) + (regions : T.RegionId.id T.region list) (types : T.rty list) : T.rty list = + let r_subst = + make_region_subst + (List.map (fun x -> x.T.index) def.T.region_params) + regions + in + let ty_subst = + make_type_subst (List.map (fun x -> x.T.index) def.T.type_params) types + in + let fields = TU.type_decl_get_fields def opt_variant_id in + List.map (fun f -> ty_substitute r_subst ty_subst f.T.field_ty) fields + +(** Return the types of the properly instantiated ADT's variant, provided a + context *) +let ctx_adt_get_instantiated_field_rtypes (ctx : C.eval_ctx) + (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) + (regions : T.RegionId.id T.region list) (types : T.rty list) : T.rty list = + let def = C.ctx_lookup_type_decl ctx def_id in + type_decl_get_instantiated_field_rtypes def opt_variant_id regions types + +(** Return the types of the properly instantiated ADT value (note that + here, ADT is understood in its broad meaning: ADT, assumed value or tuple) *) +let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx) + (adt : V.adt_value) (id : T.type_id) + (region_params : T.RegionId.id T.region list) (type_params : T.rty list) : + T.rty list = + match id with + | T.AdtId id -> + (* Retrieve the types of the fields *) + ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id + region_params type_params + | T.Tuple -> + assert (List.length region_params = 0); + type_params + | T.Assumed aty -> ( + match aty with + | T.Box | T.Vec -> + assert (List.length region_params = 0); + assert (List.length type_params = 1); + type_params + | T.Option -> + assert (List.length region_params = 0); + assert (List.length type_params = 1); + if adt.V.variant_id = Some T.option_some_id then type_params + else if adt.V.variant_id = Some T.option_none_id then [] + else failwith "Unrechable") + +(** Instantiate the type variables in an ADT definition, and return the list + of types of the fields for the chosen variant *) +let type_decl_get_instantiated_field_etypes (def : T.type_decl) + (opt_variant_id : T.VariantId.id option) (types : T.ety list) : T.ety list = + let ty_subst = + make_type_subst (List.map (fun x -> x.T.index) def.T.type_params) types + in + let fields = TU.type_decl_get_fields def opt_variant_id in + List.map + (fun f -> erase_regions_substitute_types ty_subst f.T.field_ty) + fields + +(** Return the types of the properly instantiated ADT's variant, provided a + context *) +let ctx_adt_get_instantiated_field_etypes (ctx : C.eval_ctx) + (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) + (types : T.ety list) : T.ety list = + let def = C.ctx_lookup_type_decl ctx def_id in + type_decl_get_instantiated_field_etypes def opt_variant_id types + +(** Apply a type substitution to a place *) +let place_substitute (_tsubst : T.TypeVarId.id -> T.ety) (p : E.place) : E.place + = + (* There is nothing to do *) + p + +(** Apply a type substitution to an operand *) +let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) (op : E.operand) : + E.operand = + let p_subst = place_substitute tsubst in + match op with + | E.Copy p -> E.Copy (p_subst p) + | E.Move p -> E.Move (p_subst p) + | E.Constant (ety, cv) -> + let rsubst x = x in + E.Constant (ty_substitute rsubst tsubst ety, cv) + +(** Apply a type substitution to an rvalue *) +let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety) (rv : E.rvalue) : + E.rvalue = + let op_subst = operand_substitute tsubst in + let p_subst = place_substitute tsubst in + match rv with + | E.Use op -> E.Use (op_subst op) + | E.Ref (p, bkind) -> E.Ref (p_subst p, bkind) + | E.UnaryOp (unop, op) -> E.UnaryOp (unop, op_subst op) + | E.BinaryOp (binop, op1, op2) -> + E.BinaryOp (binop, op_subst op1, op_subst op2) + | E.Discriminant p -> E.Discriminant (p_subst p) + | E.Aggregate (kind, ops) -> + let ops = List.map op_subst ops in + let kind = + match kind with + | E.AggregatedTuple -> E.AggregatedTuple + | E.AggregatedOption (variant_id, ty) -> + let rsubst r = r in + E.AggregatedOption (variant_id, ty_substitute rsubst tsubst ty) + | E.AggregatedAdt (def_id, variant_id, regions, tys) -> + let rsubst r = r in + E.AggregatedAdt + ( def_id, + variant_id, + regions, + List.map (ty_substitute rsubst tsubst) tys ) + in + E.Aggregate (kind, ops) + +(** Apply a type substitution to an assertion *) +let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety) (a : A.assertion) : + A.assertion = + { A.cond = operand_substitute tsubst a.A.cond; A.expected = a.A.expected } + +(** Apply a type substitution to a call *) +let call_substitute (tsubst : T.TypeVarId.id -> T.ety) (call : A.call) : A.call + = + let rsubst x = x in + let type_args = List.map (ty_substitute rsubst tsubst) call.A.type_args in + let args = List.map (operand_substitute tsubst) call.A.args in + let dest = place_substitute tsubst call.A.dest in + (* Putting all the paramters on purpose: we want to get a compiler error if + something moves - we may add a field on which we need to apply a substitution *) + { + func = call.A.func; + region_args = call.A.region_args; + A.type_args; + args; + dest; + } + +(** Apply a type substitution to a statement *) +let rec statement_substitute (tsubst : T.TypeVarId.id -> T.ety) + (st : A.statement) : A.statement = + { st with A.content = raw_statement_substitute tsubst st.content } + +and raw_statement_substitute (tsubst : T.TypeVarId.id -> T.ety) + (st : A.raw_statement) : A.raw_statement = + match st with + | A.Assign (p, rvalue) -> + let p = place_substitute tsubst p in + let rvalue = rvalue_substitute tsubst rvalue in + A.Assign (p, rvalue) + | A.AssignGlobal g -> + (* Globals don't have type parameters *) + A.AssignGlobal g + | A.FakeRead p -> + let p = place_substitute tsubst p in + A.FakeRead p + | A.SetDiscriminant (p, vid) -> + let p = place_substitute tsubst p in + A.SetDiscriminant (p, vid) + | A.Drop p -> + let p = place_substitute tsubst p in + A.Drop p + | A.Assert assertion -> + let assertion = assertion_substitute tsubst assertion in + A.Assert assertion + | A.Call call -> + let call = call_substitute tsubst call in + A.Call call + | A.Panic | A.Return | A.Break _ | A.Continue _ | A.Nop -> st + | A.Sequence (st1, st2) -> + A.Sequence + (statement_substitute tsubst st1, statement_substitute tsubst st2) + | A.Switch (op, tgts) -> + A.Switch + (operand_substitute tsubst op, switch_targets_substitute tsubst tgts) + | A.Loop le -> A.Loop (statement_substitute tsubst le) + +(** Apply a type substitution to switch targets *) +and switch_targets_substitute (tsubst : T.TypeVarId.id -> T.ety) + (tgts : A.switch_targets) : A.switch_targets = + match tgts with + | A.If (st1, st2) -> + A.If (statement_substitute tsubst st1, statement_substitute tsubst st2) + | A.SwitchInt (int_ty, tgts, otherwise) -> + let tgts = + List.map (fun (sv, st) -> (sv, statement_substitute tsubst st)) tgts + in + let otherwise = statement_substitute tsubst otherwise in + A.SwitchInt (int_ty, tgts, otherwise) + +(** Apply a type substitution to a function body. Return the local variables + and the body. *) +let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety) + (body : A.fun_body) : A.var list * A.statement = + let rsubst r = r in + let locals = + List.map + (fun v -> { v with A.var_ty = ty_substitute rsubst tsubst v.A.var_ty }) + body.A.locals + in + let body = statement_substitute tsubst body.body in + (locals, body) + +(** Substitute a function signature *) +let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) + (rsubst : T.RegionVarId.id -> T.RegionId.id) + (tsubst : T.TypeVarId.id -> T.rty) (sg : A.fun_sig) : A.inst_fun_sig = + let rsubst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region = + match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) + in + let inputs = List.map (ty_substitute rsubst' tsubst) sg.A.inputs in + let output = ty_substitute rsubst' tsubst sg.A.output in + let subst_region_group (rg : T.region_var_group) : A.abs_region_group = + let id = asubst rg.id in + let regions = List.map rsubst rg.regions in + let parents = List.map asubst rg.parents in + { id; regions; parents } + in + let regions_hierarchy = List.map subst_region_group sg.A.regions_hierarchy in + { A.regions_hierarchy; inputs; output } |