summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-02-08 23:00:17 +0100
committerSon Ho2022-02-08 23:00:17 +0100
commit5703ce3122bcfb69285a7f04abc8d80313a0747a (patch)
treea424f6dc1bb0598e3e47f1a3cc2ec4e15607dc91 /src
parent229a9881fa26dce69b81524445045e7b1efcc6fc (diff)
Add type checking utilities for the pure ADT
Diffstat (limited to 'src')
-rw-r--r--src/PrintPure.ml18
-rw-r--r--src/PureUtils.ml107
-rw-r--r--src/SymbolicToPure.ml153
-rw-r--r--src/Translate.ml14
4 files changed, 224 insertions, 68 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml
index cf865a54..f66aadfb 100644
--- a/src/PrintPure.ml
+++ b/src/PrintPure.ml
@@ -262,28 +262,36 @@ let adt_g_value_to_string (fmt : value_formatter)
if variant_id = result_return_id then
match field_values with
| [ v ] -> "@Result::Return " ^ v
- | _ -> failwith "Result::Return takes exactly one value"
+ | _ -> raise (Failure "Result::Return takes exactly one value")
else if variant_id = result_fail_id then (
assert (field_values = []);
"@Result::Fail")
- else failwith "Unreachable: improper variant id for result type"
+ else
+ raise (Failure "Unreachable: improper variant id for result type")
| Option ->
let variant_id = Option.get variant_id in
if variant_id = option_some_id then
match field_values with
| [ v ] -> "@Option::Some " ^ v
- | _ -> failwith "Option::Some takes exactly one value"
+ | _ -> raise (Failure "Option::Some takes exactly one value")
else if variant_id = option_none_id then (
assert (field_values = []);
"@Option::None")
- else failwith "Unreachable: improper variant id for result type"
+ else
+ raise (Failure "Unreachable: improper variant id for result type")
| Vec ->
assert (variant_id = None);
let field_values =
List.mapi (fun i v -> string_of_int i ^ " -> " ^ v) field_values
in
"Vec [" ^ String.concat "; " field_values ^ "]")
- | _ -> failwith "Inconsistent typed value"
+ | _ ->
+ let fmt = value_to_type_formatter fmt in
+ raise
+ (Failure
+ ("Inconsistently typed value: expected ADT type but found:"
+ ^ "\n- ty: " ^ ty_to_string fmt ty ^ "\n- variant_id: "
+ ^ Print.option_to_string VariantId.to_string variant_id))
let rec typed_lvalue_to_string (fmt : value_formatter) (v : typed_lvalue) :
string =
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index 6ae8184d..662902e6 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -170,6 +170,34 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty
in
fun id -> TypeVarId.Map.find id mp
+(** Retrieve the list of fields for the given variant of a [type_def].
+
+ Raises [Invalid_argument] if the arguments are incorrect.
+ *)
+let type_def_get_fields (def : type_def) (opt_variant_id : VariantId.id option)
+ : field list =
+ match (def.kind, opt_variant_id) with
+ | Enum variants, Some variant_id -> (VariantId.nth variants variant_id).fields
+ | Struct fields, None -> fields
+ | _ ->
+ let opt_variant_id =
+ match opt_variant_id with None -> "None" | Some _ -> "Some"
+ in
+ raise
+ (Invalid_argument
+ ("The variant id should be [Some] if and only if the definition is \
+ an enumeration:\n\
+ - def: " ^ show_type_def def ^ "\n- opt_variant_id: "
+ ^ opt_variant_id))
+
+(** Instantiate the type variables for the chosen variant in an ADT definition,
+ and return the list of the types of its fields *)
+let type_def_get_instantiated_fields_types (def : type_def)
+ (opt_variant_id : VariantId.id option) (types : ty list) : ty list =
+ let ty_subst = make_type_subst def.type_params types in
+ let fields = type_def_get_fields def opt_variant_id in
+ List.map (fun f -> ty_substitute ty_subst f.field_ty) fields
+
let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
inst_fun_sig =
let subst = ty_substitute tsubst in
@@ -228,3 +256,82 @@ let rec expression_requires_parentheses (e : texpression) : bool =
if monadic then true else expression_requires_parentheses next_e
| Switch (_, _) -> false
| Meta (_, next_e) -> expression_requires_parentheses next_e
+
+(** Module to perform type checking - we use this for sanity checks only *)
+module TypeCheck = struct
+ type tc_ctx = { type_defs : type_def TypeDefId.Map.t }
+
+ let check_constant_value (ty : ty) (v : constant_value) : unit =
+ match (ty, v) with
+ | Integer int_ty, V.Scalar sv -> assert (int_ty = sv.V.int_ty)
+ | Bool, Bool _ | Char, Char _ | Str, String _ -> ()
+ | _ -> raise (Failure "Inconsistent type")
+
+ let check_adt_g_value (ctx : tc_ctx) (check_value : ty -> 'v -> unit)
+ (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) :
+ unit =
+ (* let field_values = List.map value_to_string field_values in *)
+ (* Retrieve the field types *)
+ let field_tys =
+ match ty with
+ | Adt (Tuple, tys) ->
+ (* Tuple *)
+ tys
+ | Adt (AdtId def_id, tys) ->
+ (* "Regular" ADT *)
+ let def = TypeDefId.Map.find def_id ctx.type_defs in
+ type_def_get_instantiated_fields_types def variant_id tys
+ | Adt (Assumed aty, tys) -> (
+ (* Assumed type *)
+ match aty with
+ | Result ->
+ let ty = Collections.List.to_cons_nil tys in
+ let variant_id = Option.get variant_id in
+ if variant_id = result_return_id then [ ty ]
+ else if variant_id = result_fail_id then []
+ else
+ raise
+ (Failure "Unreachable: improper variant id for result type")
+ | Option ->
+ let ty = Collections.List.to_cons_nil tys in
+ let variant_id = Option.get variant_id in
+ if variant_id = option_some_id then [ ty ]
+ else if variant_id = option_none_id then []
+ else
+ raise
+ (Failure "Unreachable: improper variant id for result type")
+ | Vec ->
+ assert (variant_id = None);
+ let ty = Collections.List.to_cons_nil tys in
+ List.map (fun _ -> ty) field_values)
+ | _ -> raise (Failure "Inconsistently typed value")
+ in
+ (* Check that the field values have the expected types *)
+ List.iter
+ (fun (ty, v) -> check_value ty v)
+ (List.combine field_tys field_values)
+
+ let rec check_typed_lvalue (ctx : tc_ctx) (v : typed_lvalue) : unit =
+ match v.value with
+ | LvConcrete cv -> check_constant_value v.ty cv
+ | LvVar _ -> ()
+ | LvAdt av ->
+ check_adt_g_value ctx
+ (fun ty (v : typed_lvalue) ->
+ assert (ty = v.ty);
+ check_typed_lvalue ctx v)
+ av.variant_id av.field_values v.ty
+
+ let rec check_typed_rvalue (ctx : tc_ctx) (v : typed_rvalue) : unit =
+ match v.value with
+ | RvConcrete cv -> check_constant_value v.ty cv
+ | RvPlace _ ->
+ (* TODO: *)
+ ()
+ | RvAdt av ->
+ check_adt_g_value ctx
+ (fun ty (v : typed_rvalue) ->
+ assert (ty = v.ty);
+ check_typed_rvalue ctx v)
+ av.variant_id av.field_values v.ty
+end
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 46d2205c..0a4d0176 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -14,6 +14,12 @@ let log = L.symbolic_to_pure_log
type type_context = {
cfim_type_defs : T.type_def TypeDefId.Map.t;
+ type_defs : type_def TypeDefId.Map.t;
+ (** We use this for type-checking (for sanity checks) when translating
+ values and functions.
+ This map is empty when we translate the types, then contains all
+ the translated types when we translate the functions.
+ *)
types_infos : TA.type_infos; (* TODO: rename to type_infos *)
}
@@ -71,6 +77,14 @@ type bs_ctx = {
}
(** Body synthesis context *)
+let type_check_rvalue (ctx : bs_ctx) (v : typed_rvalue) : unit =
+ let ctx = { TypeCheck.type_defs = ctx.type_context.type_defs } in
+ TypeCheck.check_typed_rvalue ctx v
+
+let type_check_lvalue (ctx : bs_ctx) (v : typed_lvalue) : unit =
+ let ctx = { TypeCheck.type_defs = ctx.type_context.type_defs } in
+ TypeCheck.check_typed_lvalue ctx v
+
(* TODO: move *)
let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.CfimAst.ast_formatter =
Print.CfimAst.fun_def_to_ast_formatter ctx.type_context.cfim_type_defs
@@ -587,6 +601,9 @@ let rec typed_value_to_rvalue (ctx : bs_ctx) (v : V.typed_value) : typed_rvalue
in
let ty = ctx_translate_fwd_ty ctx v.ty in
let value = { value; ty } in
+ (* Sanity check *)
+ type_check_rvalue ctx value;
+ (* Return *)
value
(** Explore an abstraction value and convert it to a consumed value
@@ -607,30 +624,37 @@ let rec typed_value_to_rvalue (ctx : bs_ctx) (v : V.typed_value) : typed_rvalue
let rec typed_avalue_to_consumed (ctx : bs_ctx) (av : V.typed_avalue) :
typed_rvalue option =
let translate = typed_avalue_to_consumed ctx in
- match av.value with
- | AConcrete _ -> failwith "Unreachable"
- | AAdt adt_v -> (
- (* Translate the field values *)
- let field_values = List.filter_map translate adt_v.field_values in
- (* For now, only tuples can contain borrows *)
- let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in
- match adt_id with
- | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) ->
- assert (field_values = []);
- None
- | T.Tuple ->
- (* Return *)
- if field_values = [] then None
- else
- (* Note that if there is exactly one field value,
- * [mk_simpl_tuple_rvalue] is the identity *)
- let rv = mk_simpl_tuple_rvalue field_values in
- Some rv)
- | ABottom -> failwith "Unreachable"
- | ALoan lc -> aloan_content_to_consumed ctx lc
- | ABorrow bc -> aborrow_content_to_consumed ctx bc
- | ASymbolic aproj -> aproj_to_consumed ctx aproj
- | AIgnored -> None
+ let value =
+ match av.value with
+ | AConcrete _ -> failwith "Unreachable"
+ | AAdt adt_v -> (
+ (* Translate the field values *)
+ let field_values = List.filter_map translate adt_v.field_values in
+ (* For now, only tuples can contain borrows *)
+ let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in
+ match adt_id with
+ | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) ->
+ assert (field_values = []);
+ None
+ | T.Tuple ->
+ (* Return *)
+ if field_values = [] then None
+ else
+ (* Note that if there is exactly one field value,
+ * [mk_simpl_tuple_rvalue] is the identity *)
+ let rv = mk_simpl_tuple_rvalue field_values in
+ Some rv)
+ | ABottom -> failwith "Unreachable"
+ | ALoan lc -> aloan_content_to_consumed ctx lc
+ | ABorrow bc -> aborrow_content_to_consumed ctx bc
+ | ASymbolic aproj -> aproj_to_consumed ctx aproj
+ | AIgnored -> None
+ in
+ (* Sanity check - Rk.: we do this at every recursive call, which is a bit
+ * expansive... *)
+ (match value with None -> () | Some value -> type_check_rvalue ctx value);
+ (* Return *)
+ value
and aloan_content_to_consumed (ctx : bs_ctx) (lc : V.aloan_content) :
typed_rvalue option =
@@ -731,43 +755,50 @@ let translate_opt_mplace (p : S.mplace option) : mplace option =
*)
let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue)
(ctx : bs_ctx) : bs_ctx * typed_lvalue option =
- match av.value with
- | AConcrete _ -> failwith "Unreachable"
- | AAdt adt_v -> (
- (* Translate the field values *)
- (* For now we forget the meta-place information so that it doesn't get used
- * by several fields (which would then all have the same name...), but we
- * might want to do something smarter *)
- let mp = None in
- let ctx, field_values =
- List.fold_left_map
- (fun ctx fv -> typed_avalue_to_given_back mp fv ctx)
- ctx adt_v.field_values
- in
- let field_values = List.filter_map (fun x -> x) field_values in
- (* For now, only tuples can contain borrows - note that if we gave
- * something like a `&mut Vec` to a function, we give give back the
- * vector value upon visiting the "abstraction borrow" node *)
- let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in
- match adt_id with
- | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) ->
- assert (field_values = []);
- (ctx, None)
- | T.Tuple ->
- (* Return *)
- let variant_id = adt_v.variant_id in
- assert (variant_id = None);
- if field_values = [] then (ctx, None)
- else
- (* Note that if there is exactly one field value, [mk_simpl_tuple_lvalue]
- * is the identity *)
- let lv = mk_simpl_tuple_lvalue field_values in
- (ctx, Some lv))
- | ABottom -> failwith "Unreachable"
- | ALoan lc -> aloan_content_to_given_back mp lc ctx
- | ABorrow bc -> aborrow_content_to_given_back mp bc ctx
- | ASymbolic aproj -> aproj_to_given_back mp aproj ctx
- | AIgnored -> (ctx, None)
+ let ctx, value =
+ match av.value with
+ | AConcrete _ -> failwith "Unreachable"
+ | AAdt adt_v -> (
+ (* Translate the field values *)
+ (* For now we forget the meta-place information so that it doesn't get used
+ * by several fields (which would then all have the same name...), but we
+ * might want to do something smarter *)
+ let mp = None in
+ let ctx, field_values =
+ List.fold_left_map
+ (fun ctx fv -> typed_avalue_to_given_back mp fv ctx)
+ ctx adt_v.field_values
+ in
+ let field_values = List.filter_map (fun x -> x) field_values in
+ (* For now, only tuples can contain borrows - note that if we gave
+ * something like a `&mut Vec` to a function, we give give back the
+ * vector value upon visiting the "abstraction borrow" node *)
+ let adt_id, _, _ = TypesUtils.ty_as_adt av.ty in
+ match adt_id with
+ | T.AdtId _ | T.Assumed (T.Box | T.Vec | T.Option) ->
+ assert (field_values = []);
+ (ctx, None)
+ | T.Tuple ->
+ (* Return *)
+ let variant_id = adt_v.variant_id in
+ assert (variant_id = None);
+ if field_values = [] then (ctx, None)
+ else
+ (* Note that if there is exactly one field value, [mk_simpl_tuple_lvalue]
+ * is the identity *)
+ let lv = mk_simpl_tuple_lvalue field_values in
+ (ctx, Some lv))
+ | ABottom -> failwith "Unreachable"
+ | ALoan lc -> aloan_content_to_given_back mp lc ctx
+ | ABorrow bc -> aborrow_content_to_given_back mp bc ctx
+ | ASymbolic aproj -> aproj_to_given_back mp aproj ctx
+ | AIgnored -> (ctx, None)
+ in
+ (* Sanity check - Rk.: we do this at every recursive call, which is a bit
+ * expansive... *)
+ (match value with None -> () | Some value -> type_check_lvalue ctx value);
+ (* Return *)
+ (ctx, value)
and aloan_content_to_given_back (_mp : mplace option) (lc : V.aloan_content)
(ctx : bs_ctx) : bs_ctx * typed_lvalue option =
diff --git a/src/Translate.ml b/src/Translate.ml
index ba975c60..028114cf 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -61,7 +61,8 @@ let translate_function_to_symbolics (config : C.partial_config)
let translate_function_to_pure (config : C.partial_config)
(trans_ctx : trans_ctx)
(fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t)
- (fdef : A.fun_def) : pure_fun_translation =
+ (pure_type_defs : Pure.type_def Pure.TypeDefId.Map.t) (fdef : A.fun_def) :
+ pure_fun_translation =
(* Debug *)
log#ldebug
(lazy ("translate_function_to_pure: " ^ Print.name_to_string fdef.A.name));
@@ -91,6 +92,7 @@ let translate_function_to_pure (config : C.partial_config)
{
SymbolicToPure.types_infos = type_context.type_infos;
cfim_type_defs = type_context.type_defs;
+ type_defs = pure_type_defs;
}
in
let fun_context =
@@ -216,6 +218,12 @@ let translate_module_to_pure (config : C.partial_config)
(* Translate all the type definitions *)
let type_defs = SymbolicToPure.translate_type_defs m.types in
+ (* Compute the type definition map *)
+ let type_defs_map =
+ Pure.TypeDefId.Map.of_list
+ (List.map (fun (def : Pure.type_def) -> (def.def_id, def)) type_defs)
+ in
+
(* Translate all the function *signatures* *)
let assumed_sigs =
List.map
@@ -240,7 +248,9 @@ let translate_module_to_pure (config : C.partial_config)
(* Translate all the functions *)
let pure_translations =
- List.map (translate_function_to_pure config trans_ctx fun_sigs) m.functions
+ List.map
+ (translate_function_to_pure config trans_ctx fun_sigs type_defs_map)
+ m.functions
in
(* Apply the micro-passes *)