diff options
author | Son Ho | 2022-02-08 23:00:17 +0100 |
---|---|---|
committer | Son Ho | 2022-02-08 23:00:17 +0100 |
commit | 5703ce3122bcfb69285a7f04abc8d80313a0747a (patch) | |
tree | a424f6dc1bb0598e3e47f1a3cc2ec4e15607dc91 /src | |
parent | 229a9881fa26dce69b81524445045e7b1efcc6fc (diff) |
Add type checking utilities for the pure ADT
Diffstat (limited to '')
-rw-r--r-- | src/PrintPure.ml | 18 | ||||
-rw-r--r-- | src/PureUtils.ml | 107 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 153 | ||||
-rw-r--r-- | src/Translate.ml | 14 |
4 files changed, 224 insertions, 68 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml index cf865a54..f66aadfb 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -262,28 +262,36 @@ let adt_g_value_to_string (fmt : value_formatter) if variant_id = result_return_id then match field_values with | [ v ] -> "@Result::Return " ^ v - | _ -> failwith "Result::Return takes exactly one value" + | _ -> raise (Failure "Result::Return takes exactly one value") else if variant_id = result_fail_id then ( assert (field_values = []); "@Result::Fail") - else failwith "Unreachable: improper variant id for result type" + else + raise (Failure "Unreachable: improper variant id for result type") | Option -> let variant_id = Option.get variant_id in if variant_id = option_some_id then match field_values with | [ v ] -> "@Option::Some " ^ v - | _ -> failwith "Option::Some takes exactly one value" + | _ -> raise (Failure "Option::Some takes exactly one value") else if variant_id = option_none_id then ( assert (field_values = []); "@Option::None") - else failwith "Unreachable: improper variant id for result type" + else + raise (Failure "Unreachable: improper variant id for result type") | Vec -> assert (variant_id = None); let field_values = List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values in "Vec [" ^ String.concat "; " field_values ^ "]") - | _ -> failwith "Inconsistent typed value" + | _ -> + let fmt = value_to_type_formatter fmt in + raise + (Failure + ("Inconsistently typed value: expected ADT type but found:" + ^ "\n- ty: " ^ ty_to_string fmt ty ^ "\n- variant_id: " + ^ Print.option_to_string VariantId.to_string variant_id)) let rec typed_lvalue_to_string (fmt : value_formatter) (v : typed_lvalue) : string = diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 6ae8184d..662902e6 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -170,6 +170,34 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty in fun id -> TypeVarId.Map.find id mp +(** Retrieve the list of fields for the given variant of a [type_def]. + + Raises [Invalid_argument] if the arguments are incorrect. + *) +let type_def_get_fields (def : type_def) (opt_variant_id : VariantId.id option) + : field list = + match (def.kind, opt_variant_id) with + | Enum variants, Some variant_id -> (VariantId.nth variants variant_id).fields + | Struct fields, None -> fields + | _ -> + let opt_variant_id = + match opt_variant_id with None -> "None" | Some _ -> "Some" + in + raise + (Invalid_argument + ("The variant id should be [Some] if and only if the definition is \ + an enumeration:\n\ + - def: " ^ show_type_def def ^ "\n- opt_variant_id: " + ^ opt_variant_id)) + +(** Instantiate the type variables for the chosen variant in an ADT definition, + and return the list of the types of its fields *) +let type_def_get_instantiated_fields_types (def : type_def) + (opt_variant_id : VariantId.id option) (types : ty list) : ty list = + let ty_subst = make_type_subst def.type_params types in + let fields = type_def_get_fields def opt_variant_id in + List.map (fun f -> ty_substitute ty_subst f.field_ty) fields + let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : inst_fun_sig = let subst = ty_substitute tsubst in @@ -228,3 +256,82 @@ let rec expression_requires_parentheses (e : texpression) : bool = if monadic then true else expression_requires_parentheses next_e | Switch (_, _) -> false | Meta (_, next_e) -> expression_requires_parentheses next_e + +(** Module to perform type checking - we use this for sanity checks only *) +module TypeCheck = struct + type tc_ctx = { type_defs : type_def TypeDefId.Map.t } + + let check_constant_value (ty : ty) (v : constant_value) : unit = + match (ty, v) with + | Integer int_ty, V.Scalar sv -> assert (int_ty = sv.V.int_ty) + | Bool, Bool _ | Char, Char _ | Str, String _ -> () + | _ -> raise (Failure "Inconsistent type") + + let check_adt_g_value (ctx : tc_ctx) (check_value : ty -> 'v -> unit) + (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : + unit = + (* let field_values = List.map value_to_string field_values in *) + (* Retrieve the field types *) + let field_tys = + match ty with + | Adt (Tuple, tys) -> + (* Tuple *) + tys + | Adt (AdtId def_id, tys) -> + (* "Regular" ADT *) + let def = TypeDefId.Map.find def_id ctx.type_defs in + type_def_get_instantiated_fields_types def variant_id tys + | Adt (Assumed aty, tys) -> ( + (* Assumed type *) + match aty with + | Result -> + let ty = Collections.List.to_cons_nil tys in + let variant_id = Option.get variant_id in + if variant_id = result_return_id then [ ty ] + else if variant_id = result_fail_id then [] + else + raise + (Failure "Unreachable: improper variant id for result type") + | Option -> + let ty = Collections.List.to_cons_nil tys in + let variant_id = Option.get variant_id in + if variant_id = option_some_id then [ ty ] + else if variant_id = option_none_id then [] + else + raise + (Failure "Unreachable: improper variant id for result type") + | Vec -> + assert (variant_id = None); + let ty = Collections.List.to_cons_nil tys in + List.map (fun _ -> ty) field_values) + | _ -> raise (Failure "Inconsistently typed value") + in + (* Check that the field values have the expected types *) + List.iter + (fun (ty, v) -> check_value ty v) + (List.combine field_tys field_values) + + let rec check_typed_lvalue (ctx : tc_ctx) (v : typed_lvalue) : unit = + match v.value with + | LvConcrete cv -> check_constant_value v.ty cv + | LvVar _ -> () + | LvAdt av -> + check_adt_g_value ctx + (fun ty (v : typed_lvalue) -> + assert (ty = v.ty); + check_typed_lvalue ctx v) + av.variant_id av.field_values v.ty + + let rec check_typed_rvalue (ctx : tc_ctx) (v : typed_rvalue) : unit = + match v.value with + | RvConcrete cv -> check_constant_value v.ty cv + | RvPlace _ -> + (* TODO: *) + () + | RvAdt av -> + check_adt_g_value ctx + (fun ty (v : typed_rvalue) -> + assert (ty = v.ty); + check_typed_rvalue ctx v) + av.variant_id av.field_values v.ty +end diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 46d2205c..0a4d0176 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -14,6 +14,12 @@ let log = L.symbolic_to_pure_log type type_context = { cfim_type_defs : T.type_def TypeDefId.Map.t; + type_defs : type_def TypeDefId.Map.t; + (** We use this for type-checking (for sanity checks) when translating + values and functions. + This map is empty when we translate the types, then contains all + the translated types when we translate the functions. + *) types_infos : TA.type_infos; (* TODO: rename to type_infos *) } @@ -71,6 +77,14 @@ type bs_ctx = { } (** Body synthesis context *) +let type_check_rvalue (ctx : bs_ctx) (v : typed_rvalue) : unit = + let ctx = { TypeCheck.type_defs = ctx.type_context.type_defs } in + TypeCheck.check_typed_rvalue ctx v + +let type_check_lvalue (ctx : bs_ctx) (v : typed_lvalue) : unit = + let ctx = { TypeCheck.type_defs = ctx.type_context.type_defs } in + TypeCheck.check_typed_lvalue ctx v + (* TODO: move *) let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.CfimAst.ast_formatter = Print.CfimAst.fun_def_to_ast_formatter ctx.type_context.cfim_type_defs @@ -587,6 +601,9 @@ let rec typed_value_to_rvalue (ctx : bs_ctx) (v : V.typed_value) : typed_rvalue in let ty = ctx_translate_fwd_ty ctx v.ty in let value = { value; ty } in + (* Sanity check *) + type_check_rvalue ctx value; + (* Return *) value (** Explore an abstraction value and convert it to a consumed value @@ -607,30 +624,37 @@ let rec typed_value_to_rvalue (ctx : bs_ctx) (v : V.typed_value) : typed_rvalue let rec typed_avalue_to_consumed (ctx : bs_ctx) (av : V.typed_avalue) : typed_rvalue option = let translate = typed_avalue_to_consumed ctx in - match av.value with - | AConcrete _ -> failwith "Unreachable" - | AAdt adt_v -> ( - (* Translate the field values *) - let field_values = List.filter_map translate adt_v.field_values in - (* For now, only tuples can contain borrows *) - let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in - match adt_id with - | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) -> - assert (field_values = []); - None - | T.Tuple -> - (* Return *) - if field_values = [] then None - else - (* Note that if there is exactly one field value, - * [mk_simpl_tuple_rvalue] is the identity *) - let rv = mk_simpl_tuple_rvalue field_values in - Some rv) - | ABottom -> failwith "Unreachable" - | ALoan lc -> aloan_content_to_consumed ctx lc - | ABorrow bc -> aborrow_content_to_consumed ctx bc - | ASymbolic aproj -> aproj_to_consumed ctx aproj - | AIgnored -> None + let value = + match av.value with + | AConcrete _ -> failwith "Unreachable" + | AAdt adt_v -> ( + (* Translate the field values *) + let field_values = List.filter_map translate adt_v.field_values in + (* For now, only tuples can contain borrows *) + let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in + match adt_id with + | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) -> + assert (field_values = []); + None + | T.Tuple -> + (* Return *) + if field_values = [] then None + else + (* Note that if there is exactly one field value, + * [mk_simpl_tuple_rvalue] is the identity *) + let rv = mk_simpl_tuple_rvalue field_values in + Some rv) + | ABottom -> failwith "Unreachable" + | ALoan lc -> aloan_content_to_consumed ctx lc + | ABorrow bc -> aborrow_content_to_consumed ctx bc + | ASymbolic aproj -> aproj_to_consumed ctx aproj + | AIgnored -> None + in + (* Sanity check - Rk.: we do this at every recursive call, which is a bit + * expansive... *) + (match value with None -> () | Some value -> type_check_rvalue ctx value); + (* Return *) + value and aloan_content_to_consumed (ctx : bs_ctx) (lc : V.aloan_content) : typed_rvalue option = @@ -731,43 +755,50 @@ let translate_opt_mplace (p : S.mplace option) : mplace option = *) let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue) (ctx : bs_ctx) : bs_ctx * typed_lvalue option = - match av.value with - | AConcrete _ -> failwith "Unreachable" - | AAdt adt_v -> ( - (* Translate the field values *) - (* For now we forget the meta-place information so that it doesn't get used - * by several fields (which would then all have the same name...), but we - * might want to do something smarter *) - let mp = None in - let ctx, field_values = - List.fold_left_map - (fun ctx fv -> typed_avalue_to_given_back mp fv ctx) - ctx adt_v.field_values - in - let field_values = List.filter_map (fun x -> x) field_values in - (* For now, only tuples can contain borrows - note that if we gave - * something like a `&mut Vec` to a function, we give give back the - * vector value upon visiting the "abstraction borrow" node *) - let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in - match adt_id with - | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) -> - assert (field_values = []); - (ctx, None) - | T.Tuple -> - (* Return *) - let variant_id = adt_v.variant_id in - assert (variant_id = None); - if field_values = [] then (ctx, None) - else - (* Note that if there is exactly one field value, [mk_simpl_tuple_lvalue] - * is the identity *) - let lv = mk_simpl_tuple_lvalue field_values in - (ctx, Some lv)) - | ABottom -> failwith "Unreachable" - | ALoan lc -> aloan_content_to_given_back mp lc ctx - | ABorrow bc -> aborrow_content_to_given_back mp bc ctx - | ASymbolic aproj -> aproj_to_given_back mp aproj ctx - | AIgnored -> (ctx, None) + let ctx, value = + match av.value with + | AConcrete _ -> failwith "Unreachable" + | AAdt adt_v -> ( + (* Translate the field values *) + (* For now we forget the meta-place information so that it doesn't get used + * by several fields (which would then all have the same name...), but we + * might want to do something smarter *) + let mp = None in + let ctx, field_values = + List.fold_left_map + (fun ctx fv -> typed_avalue_to_given_back mp fv ctx) + ctx adt_v.field_values + in + let field_values = List.filter_map (fun x -> x) field_values in + (* For now, only tuples can contain borrows - note that if we gave + * something like a `&mut Vec` to a function, we give give back the + * vector value upon visiting the "abstraction borrow" node *) + let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in + match adt_id with + | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) -> + assert (field_values = []); + (ctx, None) + | T.Tuple -> + (* Return *) + let variant_id = adt_v.variant_id in + assert (variant_id = None); + if field_values = [] then (ctx, None) + else + (* Note that if there is exactly one field value, [mk_simpl_tuple_lvalue] + * is the identity *) + let lv = mk_simpl_tuple_lvalue field_values in + (ctx, Some lv)) + | ABottom -> failwith "Unreachable" + | ALoan lc -> aloan_content_to_given_back mp lc ctx + | ABorrow bc -> aborrow_content_to_given_back mp bc ctx + | ASymbolic aproj -> aproj_to_given_back mp aproj ctx + | AIgnored -> (ctx, None) + in + (* Sanity check - Rk.: we do this at every recursive call, which is a bit + * expansive... *) + (match value with None -> () | Some value -> type_check_lvalue ctx value); + (* Return *) + (ctx, value) and aloan_content_to_given_back (_mp : mplace option) (lc : V.aloan_content) (ctx : bs_ctx) : bs_ctx * typed_lvalue option = diff --git a/src/Translate.ml b/src/Translate.ml index ba975c60..028114cf 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -61,7 +61,8 @@ let translate_function_to_symbolics (config : C.partial_config) let translate_function_to_pure (config : C.partial_config) (trans_ctx : trans_ctx) (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t) - (fdef : A.fun_def) : pure_fun_translation = + (pure_type_defs : Pure.type_def Pure.TypeDefId.Map.t) (fdef : A.fun_def) : + pure_fun_translation = (* Debug *) log#ldebug (lazy ("translate_function_to_pure: " ^ Print.name_to_string fdef.A.name)); @@ -91,6 +92,7 @@ let translate_function_to_pure (config : C.partial_config) { SymbolicToPure.types_infos = type_context.type_infos; cfim_type_defs = type_context.type_defs; + type_defs = pure_type_defs; } in let fun_context = @@ -216,6 +218,12 @@ let translate_module_to_pure (config : C.partial_config) (* Translate all the type definitions *) let type_defs = SymbolicToPure.translate_type_defs m.types in + (* Compute the type definition map *) + let type_defs_map = + Pure.TypeDefId.Map.of_list + (List.map (fun (def : Pure.type_def) -> (def.def_id, def)) type_defs) + in + (* Translate all the function *signatures* *) let assumed_sigs = List.map @@ -240,7 +248,9 @@ let translate_module_to_pure (config : C.partial_config) (* Translate all the functions *) let pure_translations = - List.map (translate_function_to_pure config trans_ctx fun_sigs) m.functions + List.map + (translate_function_to_pure config trans_ctx fun_sigs type_defs_map) + m.functions in (* Apply the micro-passes *) |