diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/PureUtils.ml | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 73794a7c..96982b4b 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -560,8 +560,8 @@ module TypeCheck = struct raise (Failure "Inconsistent types")); check_typed_pattern ctx v in - (* Check the field types - TODO: we might also want to check that the - * type of the applied constructor is correct *) + (* Check the field types: check that the field patterns have the expected + * types, and check that the field patterns themselves are well-typed *) List.fold_left (fun ctx (ty, v) -> check_value ctx ty v) ctx @@ -594,25 +594,38 @@ module TypeCheck = struct | Qualif qualif -> ( match qualif.id with | Func _ -> () (* TODO *) - | Proj { adt_id; field_id } -> + | Proj { adt_id = proj_adt_id; field_id } -> (* Note we can only project fields of structures (not enumerations) *) + (* Deconstruct the projector type *) + let adt_ty, field_ty = destruct_arrow e.ty in + let adt_id, adt_type_args = + match adt_ty with + | Adt (type_id, tys) -> (type_id, tys) + | _ -> raise (Failure "Unreachable") + in + (* Check the ADT type *) + assert (adt_id = proj_adt_id); + assert (adt_type_args = qualif.type_args); + (* Retrieve and check the expected field type *) let variant_id = None in let expected_field_tys = - get_adt_field_types ctx.type_decls adt_id variant_id + get_adt_field_types ctx.type_decls proj_adt_id variant_id qualif.type_args in let expected_field_ty = FieldId.nth expected_field_tys field_id in - let _adt_ty, field_ty = destruct_arrow e.ty in - (* TODO: check the adt_ty *) assert (expected_field_ty = field_ty) - | AdtCons id -> - (* TODO: we might also want to check the out type *) + | AdtCons id -> ( let expected_field_tys = get_adt_field_types ctx.type_decls id.adt_id id.variant_id qualif.type_args in - let field_tys, _ = destruct_arrows e.ty in - assert (expected_field_tys = field_tys)) + let field_tys, adt_ty = destruct_arrows e.ty in + assert (expected_field_tys = field_tys); + match adt_ty with + | Adt (type_id, tys) -> + assert (type_id = id.adt_id); + assert (tys = qualif.type_args) + | _ -> raise (Failure "Unreachable"))) | Let (monadic, pat, re, e_next) -> let expected_pat_ty = if monadic then destruct_result re.ty else re.ty |