diff options
Diffstat (limited to 'src/PureUtils.ml')
-rw-r--r-- | src/PureUtils.ml | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 6ae8184d..662902e6 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -170,6 +170,34 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty in fun id -> TypeVarId.Map.find id mp +(** Retrieve the list of fields for the given variant of a [type_def]. + + Raises [Invalid_argument] if the arguments are incorrect. + *) +let type_def_get_fields (def : type_def) (opt_variant_id : VariantId.id option) + : field list = + match (def.kind, opt_variant_id) with + | Enum variants, Some variant_id -> (VariantId.nth variants variant_id).fields + | Struct fields, None -> fields + | _ -> + let opt_variant_id = + match opt_variant_id with None -> "None" | Some _ -> "Some" + in + raise + (Invalid_argument + ("The variant id should be [Some] if and only if the definition is \ + an enumeration:\n\ + - def: " ^ show_type_def def ^ "\n- opt_variant_id: " + ^ opt_variant_id)) + +(** Instantiate the type variables for the chosen variant in an ADT definition, + and return the list of the types of its fields *) +let type_def_get_instantiated_fields_types (def : type_def) + (opt_variant_id : VariantId.id option) (types : ty list) : ty list = + let ty_subst = make_type_subst def.type_params types in + let fields = type_def_get_fields def opt_variant_id in + List.map (fun f -> ty_substitute ty_subst f.field_ty) fields + let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : inst_fun_sig = let subst = ty_substitute tsubst in @@ -228,3 +256,82 @@ let rec expression_requires_parentheses (e : texpression) : bool = if monadic then true else expression_requires_parentheses next_e | Switch (_, _) -> false | Meta (_, next_e) -> expression_requires_parentheses next_e + +(** Module to perform type checking - we use this for sanity checks only *) +module TypeCheck = struct + type tc_ctx = { type_defs : type_def TypeDefId.Map.t } + + let check_constant_value (ty : ty) (v : constant_value) : unit = + match (ty, v) with + | Integer int_ty, V.Scalar sv -> assert (int_ty = sv.V.int_ty) + | Bool, Bool _ | Char, Char _ | Str, String _ -> () + | _ -> raise (Failure "Inconsistent type") + + let check_adt_g_value (ctx : tc_ctx) (check_value : ty -> 'v -> unit) + (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : + unit = + (* let field_values = List.map value_to_string field_values in *) + (* Retrieve the field types *) + let field_tys = + match ty with + | Adt (Tuple, tys) -> + (* Tuple *) + tys + | Adt (AdtId def_id, tys) -> + (* "Regular" ADT *) + let def = TypeDefId.Map.find def_id ctx.type_defs in + type_def_get_instantiated_fields_types def variant_id tys + | Adt (Assumed aty, tys) -> ( + (* Assumed type *) + match aty with + | Result -> + let ty = Collections.List.to_cons_nil tys in + let variant_id = Option.get variant_id in + if variant_id = result_return_id then [ ty ] + else if variant_id = result_fail_id then [] + else + raise + (Failure "Unreachable: improper variant id for result type") + | Option -> + let ty = Collections.List.to_cons_nil tys in + let variant_id = Option.get variant_id in + if variant_id = option_some_id then [ ty ] + else if variant_id = option_none_id then [] + else + raise + (Failure "Unreachable: improper variant id for result type") + | Vec -> + assert (variant_id = None); + let ty = Collections.List.to_cons_nil tys in + List.map (fun _ -> ty) field_values) + | _ -> raise (Failure "Inconsistently typed value") + in + (* Check that the field values have the expected types *) + List.iter + (fun (ty, v) -> check_value ty v) + (List.combine field_tys field_values) + + let rec check_typed_lvalue (ctx : tc_ctx) (v : typed_lvalue) : unit = + match v.value with + | LvConcrete cv -> check_constant_value v.ty cv + | LvVar _ -> () + | LvAdt av -> + check_adt_g_value ctx + (fun ty (v : typed_lvalue) -> + assert (ty = v.ty); + check_typed_lvalue ctx v) + av.variant_id av.field_values v.ty + + let rec check_typed_rvalue (ctx : tc_ctx) (v : typed_rvalue) : unit = + match v.value with + | RvConcrete cv -> check_constant_value v.ty cv + | RvPlace _ -> + (* TODO: *) + () + | RvAdt av -> + check_adt_g_value ctx + (fun ty (v : typed_rvalue) -> + assert (ty = v.ty); + check_typed_rvalue ctx v) + av.variant_id av.field_values v.ty +end |