summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/PureUtils.ml33
1 files changed, 23 insertions, 10 deletions
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index 73794a7c..96982b4b 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -560,8 +560,8 @@ module TypeCheck = struct
raise (Failure "Inconsistent types"));
check_typed_pattern ctx v
in
- (* Check the field types - TODO: we might also want to check that the
- * type of the applied constructor is correct *)
+ (* Check the field types: check that the field patterns have the expected
+ * types, and check that the field patterns themselves are well-typed *)
List.fold_left
(fun ctx (ty, v) -> check_value ctx ty v)
ctx
@@ -594,25 +594,38 @@ module TypeCheck = struct
| Qualif qualif -> (
match qualif.id with
| Func _ -> () (* TODO *)
- | Proj { adt_id; field_id } ->
+ | Proj { adt_id = proj_adt_id; field_id } ->
(* Note we can only project fields of structures (not enumerations) *)
+ (* Deconstruct the projector type *)
+ let adt_ty, field_ty = destruct_arrow e.ty in
+ let adt_id, adt_type_args =
+ match adt_ty with
+ | Adt (type_id, tys) -> (type_id, tys)
+ | _ -> raise (Failure "Unreachable")
+ in
+ (* Check the ADT type *)
+ assert (adt_id = proj_adt_id);
+ assert (adt_type_args = qualif.type_args);
+ (* Retrieve and check the expected field type *)
let variant_id = None in
let expected_field_tys =
- get_adt_field_types ctx.type_decls adt_id variant_id
+ get_adt_field_types ctx.type_decls proj_adt_id variant_id
qualif.type_args
in
let expected_field_ty = FieldId.nth expected_field_tys field_id in
- let _adt_ty, field_ty = destruct_arrow e.ty in
- (* TODO: check the adt_ty *)
assert (expected_field_ty = field_ty)
- | AdtCons id ->
- (* TODO: we might also want to check the out type *)
+ | AdtCons id -> (
let expected_field_tys =
get_adt_field_types ctx.type_decls id.adt_id id.variant_id
qualif.type_args
in
- let field_tys, _ = destruct_arrows e.ty in
- assert (expected_field_tys = field_tys))
+ let field_tys, adt_ty = destruct_arrows e.ty in
+ assert (expected_field_tys = field_tys);
+ match adt_ty with
+ | Adt (type_id, tys) ->
+ assert (type_id = id.adt_id);
+ assert (tys = qualif.type_args)
+ | _ -> raise (Failure "Unreachable")))
| Let (monadic, pat, re, e_next) ->
let expected_pat_ty =
if monadic then destruct_result re.ty else re.ty