From 50af296306bfee9f0b127dde8abe5fb0ec1b0acb Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 1 Aug 2023 11:16:06 +0200 Subject: Start adding support for const generics --- compiler/Assumed.ml | 11 ++ compiler/Contexts.ml | 8 +- compiler/Print.ml | 35 ++++- compiler/PrintPure.ml | 18 ++- compiler/Pure.ml | 24 +--- compiler/PureTypeCheck.ml | 14 +- compiler/PureUtils.ml | 8 +- compiler/Substitute.ml | 320 ++++++++++++++++++++--------------------- compiler/SynthesizeSymbolic.ml | 8 +- compiler/TypesAnalysis.ml | 11 +- compiler/Values.ml | 15 +- 11 files changed, 246 insertions(+), 226 deletions(-) diff --git a/compiler/Assumed.ml b/compiler/Assumed.ml index e751d0ba..2fbb9044 100644 --- a/compiler/Assumed.ml +++ b/compiler/Assumed.ml @@ -53,6 +53,8 @@ module Sig = struct (** Type parameter [T] of id 0 *) let type_param_0 : T.type_var = { T.index = tvar_id_0; name = "T" } + let empty_const_generic_params : T.const_generic_var list = [] + let mk_ref_ty (r : T.RegionVarId.id T.region) (ty : T.sty) (is_mut : bool) : T.sty = let ref_kind = if is_mut then T.Mut else T.Shared in @@ -73,6 +75,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -84,6 +87,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy = []; type_params = [ type_param_0 ] (* *); + const_generic_params = empty_const_generic_params; inputs = [ tvar_0 (* T *) ]; output = mk_box_ty tvar_0 (* Box *); } @@ -95,6 +99,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy = []; type_params = [ type_param_0 ] (* *); + const_generic_params = empty_const_generic_params; inputs = [ mk_box_ty tvar_0 (* Box *) ]; output = mk_unit_ty (* () *); } @@ -112,6 +117,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params = [ type_param_0 ] (* *); + const_generic_params = empty_const_generic_params; inputs = [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box *) ]; output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *); @@ -135,6 +141,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -157,6 +164,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -180,6 +188,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -199,6 +208,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } @@ -223,6 +233,7 @@ module Sig = struct num_early_bound_regions = 0; regions_hierarchy; type_params; + const_generic_params = empty_const_generic_params; inputs; output; } diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index a425d42b..2ca5653d 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -12,7 +12,8 @@ open Identifiers in the environment, because they contain borrows for instance, typically because they might be overwritten during an assignment. *) -module DummyVarId = IdGen () +module DummyVarId = +IdGen () type dummy_var_id = DummyVarId.id [@@deriving show, ord] @@ -261,6 +262,7 @@ type eval_ctx = { global_context : global_context; region_groups : RegionGroupId.id list; type_vars : type_var list; + const_generic_vars : const_generic_var list; env : env; ended_regions : RegionId.Set.t; } @@ -269,6 +271,10 @@ type eval_ctx = { let lookup_type_var (ctx : eval_ctx) (vid : TypeVarId.id) : type_var = TypeVarId.nth ctx.type_vars vid +let lookup_const_generic_var (ctx : eval_ctx) (vid : ConstGenericVarId.id) : + const_generic_var = + ConstGenericVarId.nth ctx.const_generic_vars vid + (** Lookup a variable in the current frame *) let env_lookup_var (env : env) (vid : VarId.id) : var_binder * typed_value = (* We take care to stop at the end of current frame: different variables diff --git a/compiler/Print.ml b/compiler/Print.ml index f544c0db..23cebd4c 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -19,6 +19,8 @@ module Values = struct r_to_string : T.RegionId.id -> string; type_var_id_to_string : T.TypeVarId.id -> string; type_decl_id_to_string : T.TypeDeclId.id -> string; + const_generic_var_id_to_string : T.ConstGenericVarId.id -> string; + global_decl_id_to_string : T.GlobalDeclId.id -> string; adt_variant_to_string : T.TypeDeclId.id -> T.VariantId.id -> string; var_id_to_string : E.VarId.id -> string; adt_field_names : @@ -30,6 +32,8 @@ module Values = struct PT.r_to_string = PT.erased_region_to_string; PT.type_var_id_to_string = fmt.type_var_id_to_string; PT.type_decl_id_to_string = fmt.type_decl_id_to_string; + PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + PT.global_decl_id_to_string = fmt.global_decl_id_to_string; } let value_to_rtype_formatter (fmt : value_formatter) : PT.rtype_formatter = @@ -37,6 +41,8 @@ module Values = struct PT.r_to_string = PT.region_to_string fmt.r_to_string; PT.type_var_id_to_string = fmt.type_var_id_to_string; PT.type_decl_id_to_string = fmt.type_decl_id_to_string; + PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + PT.global_decl_id_to_string = fmt.global_decl_id_to_string; } let value_to_stype_formatter (fmt : value_formatter) : PT.stype_formatter = @@ -44,6 +50,8 @@ module Values = struct PT.r_to_string = PT.region_to_string fmt.rvar_to_string; PT.type_var_id_to_string = fmt.type_var_id_to_string; PT.type_decl_id_to_string = fmt.type_decl_id_to_string; + PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + PT.global_decl_id_to_string = fmt.global_decl_id_to_string; } let var_id_to_string (id : E.VarId.id) : string = @@ -72,16 +80,16 @@ module Values = struct string = let ty_fmt : PT.etype_formatter = value_to_etype_formatter fmt in match v.value with - | Primitive cv -> PPV.primitive_value_to_string cv + | Primitive cv -> PPV.literal_to_string cv | Adt av -> ( let field_values = List.map (typed_value_to_string fmt) av.field_values in match v.ty with - | T.Adt (T.Tuple, _, _) -> + | T.Adt (T.Tuple, _, _, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | T.Adt (T.AdtId def_id, _, _) -> + | T.Adt (T.AdtId def_id, _, _, _) -> (* "Regular" ADT *) let adt_ident = match av.variant_id with @@ -103,7 +111,7 @@ module Values = struct let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | T.Adt (T.Assumed aty, _, _) -> ( + | T.Adt (T.Assumed aty, _, _, _) -> ( (* Assumed type *) match (aty, field_values) with | Box, [ bv ] -> "@Box(" ^ bv ^ ")" @@ -188,10 +196,10 @@ module Values = struct List.map (typed_avalue_to_string fmt) av.field_values in match v.ty with - | T.Adt (T.Tuple, _, _) -> + | T.Adt (T.Tuple, _, _, _) -> (* Tuple *) "(" ^ String.concat ", " field_values ^ ")" - | T.Adt (T.AdtId def_id, _, _) -> + | T.Adt (T.AdtId def_id, _, _, _) -> (* "Regular" ADT *) let adt_ident = match av.variant_id with @@ -213,7 +221,7 @@ module Values = struct let field_values = String.concat " " field_values in adt_ident ^ " { " ^ field_values ^ " }" else adt_ident - | T.Adt (T.Assumed aty, _, _) -> ( + | T.Adt (T.Assumed aty, _, _, _) -> ( (* Assumed type *) match (aty, field_values) with | Box, [ bv ] -> "@Box(" ^ bv ^ ")" @@ -425,6 +433,8 @@ module Contexts = struct PV.r_to_string = fmt.r_to_string; PV.type_var_id_to_string = fmt.type_var_id_to_string; PV.type_decl_id_to_string = fmt.type_decl_id_to_string; + PV.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string; + PV.global_decl_id_to_string = fmt.global_decl_id_to_string; PV.adt_variant_to_string = fmt.adt_variant_to_string; PV.var_id_to_string = fmt.var_id_to_string; PV.adt_field_names = fmt.adt_field_names; @@ -450,10 +460,18 @@ module Contexts = struct let v = C.lookup_type_var ctx vid in v.name in + let const_generic_var_id_to_string vid = + let v = C.lookup_const_generic_var ctx vid in + v.name + in let type_decl_id_to_string def_id = let def = C.ctx_lookup_type_decl ctx def_id in name_to_string def.name in + let global_decl_id_to_string def_id = + let def = C.ctx_lookup_global_decl ctx def_id in + name_to_string def.name + in let adt_variant_to_string = PT.type_ctx_to_adt_variant_to_string_fun ctx.type_context.type_decls in @@ -469,6 +487,8 @@ module Contexts = struct r_to_string; type_var_id_to_string; type_decl_id_to_string; + const_generic_var_id_to_string; + global_decl_id_to_string; adt_variant_to_string; var_id_to_string; adt_field_names; @@ -492,6 +512,7 @@ module Contexts = struct r_to_string = ctx_fmt.PV.r_to_string; type_var_id_to_string = ctx_fmt.PV.type_var_id_to_string; type_decl_id_to_string = ctx_fmt.PV.type_decl_id_to_string; + const_generic_var_id_to_string = ctx_fmt.PV.const_generic_var_id_to_string; adt_variant_to_string = ctx_fmt.PV.adt_variant_to_string; var_id_to_string = ctx_fmt.PV.var_id_to_string; adt_field_names = ctx_fmt.PV.adt_field_names; diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 3f35a023..03252200 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -55,8 +55,10 @@ let fun_name_to_string = Print.fun_name_to_string let global_name_to_string = Print.global_name_to_string let option_to_string = Print.option_to_string let type_var_to_string = Print.Types.type_var_to_string -let integer_type_to_string = Print.Types.integer_type_to_string +let integer_type_to_string = Print.PrimitiveValues.integer_type_to_string +let literal_type_to_string = Print.PrimitiveValues.literal_type_to_string let scalar_value_to_string = Print.PrimitiveValues.scalar_value_to_string +let literal_to_string = Print.PrimitiveValues.literal_to_string let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t) (type_params : type_var list) : type_formatter = @@ -392,7 +394,7 @@ let adt_g_value_to_string (fmt : value_formatter) let rec typed_pattern_to_string (fmt : ast_formatter) (v : typed_pattern) : string = match v.value with - | PatConstant cv -> Print.PrimitiveValues.primitive_value_to_string cv + | PatConstant cv -> literal_to_string cv | PatVar (v, None) -> var_to_string (ast_to_type_formatter fmt) v | PatVar (v, Some mp) -> let mp = "[@mplace=" ^ mplace_to_string fmt mp ^ "]" in @@ -450,6 +452,16 @@ let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = | A.VecLen -> "alloc::vec::Vec::len" | A.VecIndex -> "core::ops::index::Index::index" | A.VecIndexMut -> "core::ops::index::IndexMut::index_mut" + | ArraySharedIndex -> "@ArraySharedIndex" + | ArrayMutIndex -> "@ArrayMutIndex" + | ArrayToSharedSlice -> "@ArrayToSharedSlice" + | ArrayToMutSlice -> "@ArrayToMutSlice" + | ArraySharedSubslice -> "@ArraySharedSubslice" + | ArrayMutSubslice -> "@ArrayMutSubslice" + | SliceSharedIndex -> "@SliceSharedIndex" + | SliceMutIndex -> "@SliceMutIndex" + | SliceSharedSubslice -> "@SliceSharedSubslice" + | SliceMutSubslice -> "@SliceMutSubslice" let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string = match fid with @@ -495,7 +507,7 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool) | Var var_id -> let s = fmt.var_id_to_string var_id in if inside then "(" ^ s ^ ")" else s - | Const cv -> Print.PrimitiveValues.primitive_value_to_string cv + | Const cv -> literal_to_string cv | App _ -> (* Recursively destruct the app, to have a pair (app, arguments list) *) let app, args = destruct_apps e in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index b251a005..5af28efd 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -187,7 +187,7 @@ type type_decl = { [@@deriving show] type scalar_value = V.scalar_value [@@deriving show] -type primitive_value = V.primitive_value [@@deriving show] +type literal = V.literal [@@deriving show] (** Because we introduce a lot of temporary variables, the list of variables is not fixed: we thus must carry all its information with the variable @@ -232,10 +232,7 @@ type variant_id = VariantId.id [@@deriving show] class ['self] iter_typed_pattern_base = object (_self : 'self) inherit [_] VisitorsRuntime.iter - - method visit_primitive_value : 'env -> primitive_value -> unit = - fun _ _ -> () - + method visit_literal : 'env -> literal -> unit = fun _ _ -> () method visit_var : 'env -> var -> unit = fun _ _ -> () method visit_mplace : 'env -> mplace -> unit = fun _ _ -> () method visit_ty : 'env -> ty -> unit = fun _ _ -> () @@ -246,10 +243,7 @@ class ['self] iter_typed_pattern_base = class ['self] map_typed_pattern_base = object (_self : 'self) inherit [_] VisitorsRuntime.map - - method visit_primitive_value : 'env -> primitive_value -> primitive_value = - fun _ x -> x - + method visit_literal : 'env -> literal -> literal = fun _ x -> x method visit_var : 'env -> var -> var = fun _ x -> x method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x method visit_ty : 'env -> ty -> ty = fun _ x -> x @@ -260,10 +254,7 @@ class ['self] map_typed_pattern_base = class virtual ['self] reduce_typed_pattern_base = object (self : 'self) inherit [_] VisitorsRuntime.reduce - - method visit_primitive_value : 'env -> primitive_value -> 'a = - fun _ _ -> self#zero - + method visit_literal : 'env -> literal -> 'a = fun _ _ -> self#zero method visit_var : 'env -> var -> 'a = fun _ _ -> self#zero method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero @@ -275,8 +266,7 @@ class virtual ['self] mapreduce_typed_pattern_base = object (self : 'self) inherit [_] VisitorsRuntime.mapreduce - method visit_primitive_value - : 'env -> primitive_value -> primitive_value * 'a = + method visit_literal : 'env -> literal -> literal * 'a = fun _ x -> (x, self#zero) method visit_var : 'env -> var -> var * 'a = fun _ x -> (x, self#zero) @@ -292,7 +282,7 @@ class virtual ['self] mapreduce_typed_pattern_base = (** A pattern (which appears on the left of assignments, in matches, etc.). *) type pattern = - | PatConstant of primitive_value + | PatConstant of literal (** {!PatConstant} is necessary because we merge the switches over integer values and the matches over enumerations *) | PatVar of var * mplace option @@ -486,7 +476,7 @@ class virtual ['self] mapreduce_expression_base = *) type expression = | Var of var_id (** a variable *) - | Const of primitive_value + | Const of literal | App of texpression * texpression (** Application of a function to an argument. diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 018ea6b5..72084dfc 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -56,17 +56,17 @@ type tc_ctx = { env : ty VarId.Map.t; (** Environment from variables to types *) } -let check_primitive_value (v : primitive_value) (ty : ty) : unit = +let check_literal (v : literal) (ty : ty) : unit = match (ty, v) with | Integer int_ty, PV.Scalar sv -> assert (int_ty = sv.PV.int_ty) - | Bool, Bool _ | Char, Char _ | Str, String _ -> () + | Bool, Bool _ | Char, Char _ -> () | _ -> raise (Failure "Inconsistent type") let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx = log#ldebug (lazy ("check_typed_pattern: " ^ show_typed_pattern v)); match v.value with | PatConstant cv -> - check_primitive_value cv v.ty; + check_literal cv v.ty; ctx | PatDummy -> ctx | PatVar (var, _) -> @@ -108,7 +108,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = match VarId.Map.find_opt var_id ctx.env with | None -> () | Some ty -> assert (ty = e.ty)) - | Const cv -> check_primitive_value cv e.ty + | Const cv -> check_literal cv e.ty | App (app, arg) -> let input_ty, output_ty = destruct_arrow app.ty in assert (input_ty = arg.ty); @@ -197,9 +197,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = | StructUpdate supd -> (* Check the init value *) (if Option.is_some supd.init then - match VarId.Map.find_opt (Option.get supd.init) ctx.env with - | None -> () - | Some ty -> assert (ty = e.ty)); + match VarId.Map.find_opt (Option.get supd.init) ctx.env with + | None -> () + | Some ty -> assert (ty = e.ty)); (* Check the fields *) (* Retrieve and check the expected field type *) let adt_id, adt_type_args = diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 647678c1..88b18e89 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -62,18 +62,16 @@ let dest_arrow_ty (ty : ty) : ty * ty = | Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) | _ -> raise (Failure "Unreachable") -let compute_primitive_value_ty (cv : primitive_value) : ty = +let compute_literal_ty (cv : literal) : ty = match cv with | PV.Scalar sv -> Integer sv.PV.int_ty | Bool _ -> Bool | Char _ -> Char - | String _ -> Str let var_get_id (v : var) : VarId.id = v.id -let mk_typed_pattern_from_primitive_value (cv : primitive_value) : typed_pattern - = - let ty = compute_primitive_value_ty cv in +let mk_typed_pattern_from_literal (cv : literal) : typed_pattern = + let ty = compute_literal_ty cv in { value = PatConstant cv; ty } let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression) diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml index 5040fd9f..a1b1572e 100644 --- a/compiler/Substitute.ml +++ b/compiler/Substitute.ml @@ -11,7 +11,8 @@ module C = Contexts (** Substitute types variables and regions in a type. *) let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty) - (ty : 'r1 T.ty) : 'r2 T.ty = + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : 'r1 T.ty) : + 'r2 T.ty = let open T in let visitor = object @@ -22,25 +23,37 @@ let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty) method! visit_type_var_id _ _ = (* We should never get here because we reimplemented [visit_TypeVar] *) raise (Failure "Unexpected") + + method! visit_ConstGenericVar _ id = cgsubst id + + method! visit_const_generic_var_id _ _ = + (* We should never get here because we reimplemented [visit_Var] *) + raise (Failure "Unexpected") end in visitor#visit_ty () ty let rty_substitute (rsubst : T.RegionId.id -> T.RegionId.id) - (tsubst : T.TypeVarId.id -> T.rty) (ty : T.rty) : T.rty = + (tsubst : T.TypeVarId.id -> T.rty) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.rty) : T.rty = let rsubst r = match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) in - ty_substitute rsubst tsubst ty + ty_substitute rsubst tsubst cgsubst ty -let ety_substitute (tsubst : T.TypeVarId.id -> T.ety) (ty : T.ety) : T.ety = +let ety_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.ety) : T.ety = let rsubst r = r in - ty_substitute rsubst tsubst ty + ty_substitute rsubst tsubst cgsubst ty (** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *) let erase_regions (ty : T.rty) : T.ety = - ty_substitute (fun _ -> T.Erased) (fun vid -> T.TypeVar vid) ty + ty_substitute + (fun _ -> T.Erased) + (fun vid -> T.TypeVar vid) + (fun id -> T.ConstGenericVar id) + ty (** Generate fresh regions for region variables. @@ -59,7 +72,7 @@ let fresh_regions_with_substs (region_vars : T.region_var list) : let ls = List.combine region_vars fresh_region_ids in let rid_map = List.fold_left - (fun mp (k, v) -> T.RegionVarId.Map.add k.T.index v mp) + (fun mp ((k : T.region_var), v) -> T.RegionVarId.Map.add k.T.index v mp) T.RegionVarId.Map.empty ls in (* Generate the substitution from region var id to region *) @@ -73,9 +86,10 @@ let fresh_regions_with_substs (region_vars : T.region_var list) : (** Erase the regions in a type and substitute the type variables *) let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : 'r T.region T.ty) : T.ety = let rsubst (_ : 'r T.region) : T.erased_region = T.Erased in - ty_substitute rsubst tsubst ty + ty_substitute rsubst tsubst cgsubst ty (** Create a region substitution from a list of region variable ids and a list of regions (with which to substitute the region variable ids *) @@ -92,6 +106,12 @@ let make_region_subst (var_ids : T.RegionVarId.id list) | T.Static -> T.Static | T.Var id -> T.RegionVarId.Map.find id mp +let make_region_subst_from_vars (vars : T.region_var list) + (regions : 'r T.region list) : T.RegionVarId.id T.region -> 'r T.region = + make_region_subst + (List.map (fun (x : T.region_var) -> x.T.index) vars) + regions + (** Create a type substitution from a list of type variable ids and a list of types (with which to substitute the type variable ids) *) let make_type_subst (var_ids : T.TypeVarId.id list) (tys : 'r T.ty list) : @@ -104,18 +124,37 @@ let make_type_subst (var_ids : T.TypeVarId.id list) (tys : 'r T.ty list) : in fun id -> T.TypeVarId.Map.find id mp +let make_type_subst_from_vars (vars : T.type_var list) (tys : 'r T.ty list) : + T.TypeVarId.id -> 'r T.ty = + make_type_subst (List.map (fun (x : T.type_var) -> x.T.index) vars) tys + +(** Create a const generic substitution from a list of const generic variable ids and a list of + const generics (with which to substitute the const generic variable ids) *) +let make_const_generic_subst (var_ids : T.ConstGenericVarId.id list) + (cgs : T.const_generic list) : T.ConstGenericVarId.id -> T.const_generic = + let ls = List.combine var_ids cgs in + let mp = + List.fold_left + (fun mp (k, v) -> T.ConstGenericVarId.Map.add k v mp) + T.ConstGenericVarId.Map.empty ls + in + fun id -> T.ConstGenericVarId.Map.find id mp + +let make_const_generic_subst_from_vars (vars : T.const_generic_var list) + (cgs : T.const_generic list) : T.ConstGenericVarId.id -> T.const_generic = + make_const_generic_subst + (List.map (fun (x : T.const_generic_var) -> x.T.index) vars) + cgs + (** Instantiate the type variables in an ADT definition, and return, for every variant, the list of the types of its fields *) let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) - (regions : T.RegionId.id T.region list) (types : T.rty list) : - (T.VariantId.id option * T.rty list) list = - let r_subst = - make_region_subst - (List.map (fun x -> x.T.index) def.T.region_params) - regions - in - let ty_subst = - make_type_subst (List.map (fun x -> x.T.index) def.T.type_params) types + (regions : T.RegionId.id T.region list) (types : T.rty list) + (cgs : T.const_generic list) : (T.VariantId.id option * T.rty list) list = + let r_subst = make_region_subst_from_vars def.T.region_params regions in + let ty_subst = make_type_subst_from_vars def.T.type_params types in + let cg_subst = + make_const_generic_subst_from_vars def.T.const_generic_params cgs in let (variants_fields : (T.VariantId.id option * T.field list) list) = match def.T.kind with @@ -133,45 +172,47 @@ let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl) List.map (fun (id, fields) -> ( id, - List.map (fun f -> ty_substitute r_subst ty_subst f.T.field_ty) fields - )) + List.map + (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty) + fields )) variants_fields (** Instantiate the type variables in an ADT definition, and return the list of types of the fields for the chosen variant *) let type_decl_get_instantiated_field_rtypes (def : T.type_decl) (opt_variant_id : T.VariantId.id option) - (regions : T.RegionId.id T.region list) (types : T.rty list) : T.rty list = - let r_subst = - make_region_subst - (List.map (fun x -> x.T.index) def.T.region_params) - regions - in - let ty_subst = - make_type_subst (List.map (fun x -> x.T.index) def.T.type_params) types + (regions : T.RegionId.id T.region list) (types : T.rty list) + (cgs : T.const_generic list) : T.rty list = + let r_subst = make_region_subst_from_vars def.T.region_params regions in + let ty_subst = make_type_subst_from_vars def.T.type_params types in + let cg_subst = + make_const_generic_subst_from_vars def.T.const_generic_params cgs in let fields = TU.type_decl_get_fields def opt_variant_id in - List.map (fun f -> ty_substitute r_subst ty_subst f.T.field_ty) fields + List.map + (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty) + fields (** Return the types of the properly instantiated ADT's variant, provided a context *) let ctx_adt_get_instantiated_field_rtypes (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (regions : T.RegionId.id T.region list) (types : T.rty list) : T.rty list = + (regions : T.RegionId.id T.region list) (types : T.rty list) + (cgs : T.const_generic list) : T.rty list = let def = C.ctx_lookup_type_decl ctx def_id in - type_decl_get_instantiated_field_rtypes def opt_variant_id regions types + type_decl_get_instantiated_field_rtypes def opt_variant_id regions types cgs (** Return the types of the properly instantiated ADT value (note that here, ADT is understood in its broad meaning: ADT, assumed value or tuple) *) let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx) (adt : V.adt_value) (id : T.type_id) - (region_params : T.RegionId.id T.region list) (type_params : T.rty list) : - T.rty list = + (region_params : T.RegionId.id T.region list) (type_params : T.rty list) + (cg_params : T.const_generic list) : T.rty list = match id with | T.AdtId id -> (* Retrieve the types of the fields *) ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id - region_params type_params + region_params type_params cg_params | T.Tuple -> assert (List.length region_params = 0); type_params @@ -180,182 +221,116 @@ let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx) | T.Box | T.Vec -> assert (List.length region_params = 0); assert (List.length type_params = 1); + assert (List.length cg_params = 0); type_params | T.Option -> assert (List.length region_params = 0); assert (List.length type_params = 1); + assert (List.length cg_params = 0); if adt.V.variant_id = Some T.option_some_id then type_params else if adt.V.variant_id = Some T.option_none_id then [] - else raise (Failure "Unreachable")) + else raise (Failure "Unreachable") + | T.Array | T.Slice | T.Str -> + (* Those types don't have fields *) + raise (Failure "Unreachable")) (** Instantiate the type variables in an ADT definition, and return the list of types of the fields for the chosen variant *) let type_decl_get_instantiated_field_etypes (def : T.type_decl) - (opt_variant_id : T.VariantId.id option) (types : T.ety list) : T.ety list = - let ty_subst = - make_type_subst (List.map (fun x -> x.T.index) def.T.type_params) types + (opt_variant_id : T.VariantId.id option) (types : T.ety list) + (cgs : T.const_generic list) : T.ety list = + let ty_subst = make_type_subst_from_vars def.T.type_params types in + let cg_subst = + make_const_generic_subst_from_vars def.T.const_generic_params cgs in let fields = TU.type_decl_get_fields def opt_variant_id in List.map - (fun f -> erase_regions_substitute_types ty_subst f.T.field_ty) + (fun f -> erase_regions_substitute_types ty_subst cg_subst f.T.field_ty) fields (** Return the types of the properly instantiated ADT's variant, provided a context *) let ctx_adt_get_instantiated_field_etypes (ctx : C.eval_ctx) (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option) - (types : T.ety list) : T.ety list = + (types : T.ety list) (cgs : T.const_generic list) : T.ety list = let def = C.ctx_lookup_type_decl ctx def_id in - type_decl_get_instantiated_field_etypes def opt_variant_id types + type_decl_get_instantiated_field_etypes def opt_variant_id types cgs + +let statement_substitute_visitor (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) = + object + inherit [_] A.map_statement + method! visit_ety _ ty = ety_substitute tsubst cgsubst ty + method! visit_ConstGenericVar _ id = cgsubst id + + method! visit_const_generic_var_id _ _ = + (* We should never get here because we reimplemented [visit_Var] *) + raise (Failure "Unexpected") + end (** Apply a type substitution to a place *) -let place_substitute (_tsubst : T.TypeVarId.id -> T.ety) (p : E.place) : E.place - = - (* There is nothing to do *) - p +let place_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (p : E.place) : + E.place = + (* There is in fact nothing to do *) + (statement_substitute_visitor tsubst cgsubst)#visit_place () p (** Apply a type substitution to an operand *) -let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) (op : E.operand) : +let operand_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (op : E.operand) : E.operand = - let p_subst = place_substitute tsubst in - match op with - | E.Copy p -> E.Copy (p_subst p) - | E.Move p -> E.Move (p_subst p) - | E.Constant (ety, cv) -> - let rsubst x = x in - E.Constant (ty_substitute rsubst tsubst ety, cv) + (statement_substitute_visitor tsubst cgsubst)#visit_operand () op (** Apply a type substitution to an rvalue *) -let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety) (rv : E.rvalue) : +let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (rv : E.rvalue) : E.rvalue = - let op_subst = operand_substitute tsubst in - let p_subst = place_substitute tsubst in - match rv with - | E.Use op -> E.Use (op_subst op) - | E.Ref (p, bkind) -> E.Ref (p_subst p, bkind) - | E.UnaryOp (unop, op) -> E.UnaryOp (unop, op_subst op) - | E.BinaryOp (binop, op1, op2) -> - E.BinaryOp (binop, op_subst op1, op_subst op2) - | E.Discriminant p -> E.Discriminant (p_subst p) - | E.Global _ -> (* Globals don't have type parameters *) rv - | E.Aggregate (kind, ops) -> - let ops = List.map op_subst ops in - let kind = - match kind with - | E.AggregatedTuple -> E.AggregatedTuple - | E.AggregatedOption (variant_id, ty) -> - let rsubst r = r in - E.AggregatedOption (variant_id, ty_substitute rsubst tsubst ty) - | E.AggregatedAdt (def_id, variant_id, regions, tys) -> - let rsubst r = r in - E.AggregatedAdt - ( def_id, - variant_id, - regions, - List.map (ty_substitute rsubst tsubst) tys ) - in - E.Aggregate (kind, ops) + (statement_substitute_visitor tsubst cgsubst)#visit_rvalue () rv (** Apply a type substitution to an assertion *) -let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety) (a : A.assertion) : +let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (a : A.assertion) : A.assertion = - { A.cond = operand_substitute tsubst a.A.cond; A.expected = a.A.expected } + (statement_substitute_visitor tsubst cgsubst)#visit_assertion () a (** Apply a type substitution to a call *) -let call_substitute (tsubst : T.TypeVarId.id -> T.ety) (call : A.call) : A.call - = - let rsubst x = x in - let type_args = List.map (ty_substitute rsubst tsubst) call.A.type_args in - let args = List.map (operand_substitute tsubst) call.A.args in - let dest = place_substitute tsubst call.A.dest in - (* Putting all the paramters on purpose: we want to get a compiler error if - something moves - we may add a field on which we need to apply a substitution *) - { - func = call.A.func; - region_args = call.A.region_args; - A.type_args; - args; - dest; - } - -(** Apply a type substitution to a statement - TODO: use a map iterator *) -let rec statement_substitute (tsubst : T.TypeVarId.id -> T.ety) - (st : A.statement) : A.statement = - { st with A.content = raw_statement_substitute tsubst st.content } - -and raw_statement_substitute (tsubst : T.TypeVarId.id -> T.ety) - (st : A.raw_statement) : A.raw_statement = - match st with - | A.Assign (p, rvalue) -> - let p = place_substitute tsubst p in - let rvalue = rvalue_substitute tsubst rvalue in - A.Assign (p, rvalue) - | A.FakeRead p -> - let p = place_substitute tsubst p in - A.FakeRead p - | A.SetDiscriminant (p, vid) -> - let p = place_substitute tsubst p in - A.SetDiscriminant (p, vid) - | A.Drop p -> - let p = place_substitute tsubst p in - A.Drop p - | A.Assert assertion -> - let assertion = assertion_substitute tsubst assertion in - A.Assert assertion - | A.Call call -> - let call = call_substitute tsubst call in - A.Call call - | A.Panic | A.Return | A.Break _ | A.Continue _ | A.Nop -> st - | A.Sequence (st1, st2) -> - A.Sequence - (statement_substitute tsubst st1, statement_substitute tsubst st2) - | A.Switch switch -> A.Switch (switch_substitute tsubst switch) - | A.Loop le -> A.Loop (statement_substitute tsubst le) - -(** Apply a type substitution to a switch *) -and switch_substitute (tsubst : T.TypeVarId.id -> T.ety) (switch : A.switch) : - A.switch = - match switch with - | A.If (op, st1, st2) -> - A.If - ( operand_substitute tsubst op, - statement_substitute tsubst st1, - statement_substitute tsubst st2 ) - | A.SwitchInt (op, int_ty, tgts, otherwise) -> - let tgts = - List.map (fun (sv, st) -> (sv, statement_substitute tsubst st)) tgts - in - let otherwise = statement_substitute tsubst otherwise in - A.SwitchInt (operand_substitute tsubst op, int_ty, tgts, otherwise) - | A.Match (p, tgts, otherwise) -> - let tgts = - List.map (fun (sv, st) -> (sv, statement_substitute tsubst st)) tgts - in - let otherwise = statement_substitute tsubst otherwise in - A.Match (place_substitute tsubst p, tgts, otherwise) +let call_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (call : A.call) : + A.call = + (statement_substitute_visitor tsubst cgsubst)#visit_call () call + +(** Apply a type substitution to a statement *) +let statement_substitute (tsubst : T.TypeVarId.id -> T.ety) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (st : A.statement) : + A.statement = + (statement_substitute_visitor tsubst cgsubst)#visit_statement () st (** Apply a type substitution to a function body. Return the local variables and the body. *) let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety) - (body : A.fun_body) : A.var list * A.statement = + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (body : A.fun_body) : + A.var list * A.statement = let rsubst r = r in let locals = List.map - (fun v -> { v with A.var_ty = ty_substitute rsubst tsubst v.A.var_ty }) + (fun (v : A.var) -> + { v with A.var_ty = ty_substitute rsubst tsubst cgsubst v.A.var_ty }) body.A.locals in - let body = statement_substitute tsubst body.body in + let body = statement_substitute tsubst cgsubst body.body in (locals, body) (** Substitute a function signature *) let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) (rsubst : T.RegionVarId.id -> T.RegionId.id) - (tsubst : T.TypeVarId.id -> T.rty) (sg : A.fun_sig) : A.inst_fun_sig = + (tsubst : T.TypeVarId.id -> T.rty) + (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (sg : A.fun_sig) : + A.inst_fun_sig = let rsubst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region = match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) in - let inputs = List.map (ty_substitute rsubst' tsubst) sg.A.inputs in - let output = ty_substitute rsubst' tsubst sg.A.output in + let inputs = List.map (ty_substitute rsubst' tsubst cgsubst) sg.A.inputs in + let output = ty_substitute rsubst' tsubst cgsubst sg.A.output in let subst_region_group (rg : T.region_var_group) : A.abs_region_group = let id = asubst rg.id in let regions = List.map rsubst rg.regions in @@ -366,7 +341,8 @@ let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) { A.regions_hierarchy; inputs; output } (** Substitute type variable identifiers in a type *) -let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) (ty : 'r T.ty) +let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty) : 'r T.ty = let open T in let visitor = @@ -374,6 +350,7 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) (ty : 'r T.ty) inherit [_] map_ty method visit_'r _ r = r method! visit_type_var_id _ id = tsubst id + method! visit_const_generic_var_id _ id = cgsubst id end in @@ -384,7 +361,7 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) (ty : 'r T.ty) We want to write a class which visits abstractions, values, etc. *and their types* to substitute identifiers. - The issue comes is that we derive two specialized types (ety and rty) from a + The issue is that we derive two specialized types (ety and rty) from a polymorphic type (ty). Because of this, there are typing issues with [visit_'r] if we define a class which visits objects of types [ety] and [rty] while inheriting a class which visit [ty]... @@ -392,6 +369,7 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) (ty : 'r T.ty) let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) = @@ -403,6 +381,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) method! visit_type_var_id _ id = tsubst id + method! visit_const_generic_var_id _ id = cgsubst id end in @@ -411,7 +390,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) inherit [_] C.map_env method! visit_borrow_id _ bid = bsubst bid method! visit_loan_id _ bid = bsubst bid - method! visit_ety _ ty = ty_substitute_ids tsubst ty + method! visit_ety _ ty = ty_substitute_ids tsubst cgsubst ty method! visit_rty env ty = subst_rty#visit_ty env ty method! visit_symbolic_value_id _ id = ssubst id @@ -444,11 +423,12 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) let typed_value_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_value) : V.typed_value = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst) + (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) #visit_typed_value v let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) @@ -458,33 +438,39 @@ let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (fun x -> x) (fun x -> x) (fun x -> x) + (fun x -> x) v let typed_avalue_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_avalue) : V.typed_avalue = let asubst _ = raise (Failure "Unreachable") in - (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst) + (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) #visit_typed_avalue v let abs_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : V.abs) : V.abs = - (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst)#visit_abs x + (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + #visit_abs x let env_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) (bsubst : V.BorrowId.id -> V.BorrowId.id) (asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : C.env) : C.env = - (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst)#visit_env x + (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst) + #visit_env x let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : V.typed_avalue) : V.typed_avalue = @@ -494,6 +480,7 @@ let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (fun x -> x) (fun x -> x) (fun x -> x) + (fun x -> x) asubst) #visit_typed_avalue x @@ -505,6 +492,7 @@ let env_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : C.env) : C.env (fun x -> x) (fun x -> x) (fun x -> x) + (fun x -> x) (fun x -> x)) #visit_env x diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index 6668c043..a6e11363 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -32,7 +32,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) (* Match on the symbolic value type to know which can of expansion happened *) let expansion = match sv.V.sv_ty with - | T.Bool -> ( + | T.Literal PV.Bool -> ( (* Boolean expansion: there should be two branches *) match ls with | [ @@ -41,7 +41,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) ] -> ExpandBool (true_exp, false_exp) | _ -> raise (Failure "Ill-formed boolean expansion")) - | T.Integer int_ty -> + | T.Literal (PV.Integer int_ty) -> (* Switch over an integer: split between the "regular" branches and the "otherwise" branch (which should be the last branch) *) let branches, otherwise = C.List.pop_last ls in @@ -64,7 +64,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) assert (otherwise_see = None); (* Return *) ExpandInt (int_ty, branches, otherwise) - | T.Adt (_, _, _) -> + | T.Adt (_, _, _, _) -> (* Branching: it is necessarily an enumeration expansion *) let get_variant (see : V.symbolic_expansion option) : T.VariantId.id option * V.symbolic_value list = @@ -85,7 +85,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) match ls with | [ (Some see, exp) ] -> ExpandNoBranch (see, exp) | _ -> raise (Failure "Ill-formed borrow expansion")) - | T.TypeVar _ | Char | Never | Str | Array _ | Slice _ -> + | T.TypeVar _ | T.Literal Char | Never -> raise (Failure "Ill-formed symbolic expansion") in Some (Expansion (place, sv, expansion)) diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml index db6d3b98..b4bb0386 100644 --- a/compiler/TypesAnalysis.ml +++ b/compiler/TypesAnalysis.ml @@ -122,7 +122,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) let rec analyze (expl_info : expl_info) (ty_info : partial_type_info) (ty : 'r ty) : partial_type_info = match ty with - | Bool | Char | Never | Integer _ | Str -> ty_info + | Literal _ | Never -> ty_info | TypeVar var_id -> ( (* Update the information for the proper parameter, if necessary *) match ty_info.param_infos with @@ -145,9 +145,6 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) in let param_infos = Some param_infos in { ty_info with param_infos }) - | Array ty | Slice ty -> - (* Just dive in *) - analyze expl_info ty_info ty | Ref (r, rty, rkind) -> (* Update the type info *) let contains_static = r_is_static r in @@ -172,12 +169,14 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) in (* Continue exploring *) analyze expl_info ty_info rty - | Adt ((Tuple | Assumed (Box | Vec | Option)), _, tys) -> + | Adt + ((Tuple | Assumed (Box | Vec | Option | Slice | Array | Str)), _, tys, _) + -> (* Nothing to update: just explore the type parameters *) List.fold_left (fun ty_info ty -> analyze expl_info ty_info ty) ty_info tys - | Adt (AdtId adt_id, regions, tys) -> + | Adt (AdtId adt_id, regions, tys, _cgs) -> (* Lookup the information for this type definition *) let adt_info = TypeDeclId.Map.find adt_id infos in (* Update the type info with the information from the adt *) diff --git a/compiler/Values.ml b/compiler/Values.ml index d4eff58a..3d6bc9c1 100644 --- a/compiler/Values.ml +++ b/compiler/Values.ml @@ -14,7 +14,7 @@ module LoopId = IdGen () type big_int = PrimitiveValues.big_int [@@deriving show, ord] type scalar_value = PrimitiveValues.scalar_value [@@deriving show, ord] -type primitive_value = PrimitiveValues.primitive_value [@@deriving show, ord] +type literal = PrimitiveValues.literal [@@deriving show, ord] type symbolic_value_id = SymbolicValueId.id [@@deriving show, ord] type symbolic_value_id_set = SymbolicValueId.Set.t [@@deriving show, ord] type loop_id = LoopId.id [@@deriving show, ord] @@ -110,10 +110,7 @@ type loan_id_set = BorrowId.Set.t [@@deriving show, ord] class ['self] iter_typed_value_base = object (self : 'self) inherit [_] iter_symbolic_value - - method visit_primitive_value : 'env -> primitive_value -> unit = - fun _ _ -> () - + method visit_literal : 'env -> literal -> unit = fun _ _ -> () method visit_erased_region : 'env -> erased_region -> unit = fun _ _ -> () method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> () method visit_ety : 'env -> ety -> unit = fun _ _ -> () @@ -131,9 +128,7 @@ class ['self] iter_typed_value_base = class ['self] map_typed_value_base = object (self : 'self) inherit [_] map_symbolic_value - - method visit_primitive_value : 'env -> primitive_value -> primitive_value = - fun _ cv -> cv + method visit_literal : 'env -> literal -> literal = fun _ cv -> cv method visit_erased_region : 'env -> erased_region -> erased_region = fun _ r -> r @@ -152,7 +147,7 @@ class ['self] map_typed_value_base = (** An untyped value, used in the environments *) type value = - | Primitive of primitive_value (** Non-symbolic primitive value *) + | Primitive of literal (** Non-symbolic primitive value *) | Adt of adt_value (** Enumerations and structures *) | Bottom (** No value (uninitialized or moved value) *) | Borrow of borrow_content (** A borrowed value *) @@ -1019,7 +1014,7 @@ type abs = { TODO: this should rather be name "expanded_symbolic" *) type symbolic_expansion = - | SePrimitive of primitive_value + | SePrimitive of literal | SeAdt of (VariantId.id option * symbolic_value list) | SeMutRef of BorrowId.id * symbolic_value | SeSharedRef of BorrowId.Set.t * symbolic_value -- cgit v1.2.3