From 50af296306bfee9f0b127dde8abe5fb0ec1b0acb Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 1 Aug 2023 11:16:06 +0200 Subject: Start adding support for const generics --- compiler/PureUtils.ml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) (limited to 'compiler/PureUtils.ml') diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 647678c1..88b18e89 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -62,18 +62,16 @@ let dest_arrow_ty (ty : ty) : ty * ty = | Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) | _ -> raise (Failure "Unreachable") -let compute_primitive_value_ty (cv : primitive_value) : ty = +let compute_literal_ty (cv : literal) : ty = match cv with | PV.Scalar sv -> Integer sv.PV.int_ty | Bool _ -> Bool | Char _ -> Char - | String _ -> Str let var_get_id (v : var) : VarId.id = v.id -let mk_typed_pattern_from_primitive_value (cv : primitive_value) : typed_pattern - = - let ty = compute_primitive_value_ty cv in +let mk_typed_pattern_from_literal (cv : literal) : typed_pattern = + let ty = compute_literal_ty cv in { value = PatConstant cv; ty } let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression) -- cgit v1.2.3 From 9d27e2e27db06eaad7565b55366ca8734b364fca Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 2 Aug 2023 11:03:59 +0200 Subject: Make progress proapagating the changes --- compiler/PureUtils.ml | 92 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 37 deletions(-) (limited to 'compiler/PureUtils.ml') diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 88b18e89..1c8d8921 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -62,7 +62,7 @@ let dest_arrow_ty (ty : ty) : ty * ty = | Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) | _ -> raise (Failure "Unreachable") -let compute_literal_ty (cv : literal) : ty = +let compute_literal_type (cv : literal) : literal_type = match cv with | PV.Scalar sv -> Integer sv.PV.int_ty | Bool _ -> Bool @@ -71,7 +71,7 @@ let compute_literal_ty (cv : literal) : ty = let var_get_id (v : var) : VarId.id = v.id let mk_typed_pattern_from_literal (cv : literal) : typed_pattern = - let ty = compute_literal_ty cv in + let ty = Literal (compute_literal_type cv) in { value = PatConstant cv; ty } let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression) @@ -90,11 +90,13 @@ let mk_mplace (var_id : E.VarId.id) (name : string option) { var_id; name; projection } (** Type substitution *) -let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty = +let ty_substitute (tsubst : TypeVarId.id -> ty) + (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty = let obj = object inherit [_] map_ty method! visit_TypeVar _ var_id = tsubst var_id + method! visit_ConstGenericVar _ var_id = cgsubst var_id end in obj#visit_ty () ty @@ -109,6 +111,10 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty in fun id -> TypeVarId.Map.find id mp +let make_const_generic_subst (vars : const_generic_var list) + (cgs : const_generic list) : ConstGenericVarId.id -> const_generic = + Substitute.make_const_generic_subst_from_vars vars cgs + (** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}. Raises [Invalid_argument] if the arguments are incorrect. @@ -132,14 +138,17 @@ let type_decl_get_fields (def : type_decl) (** Instantiate the type variables for the chosen variant in an ADT definition, and return the list of the types of its fields *) let type_decl_get_instantiated_fields_types (def : type_decl) - (opt_variant_id : VariantId.id option) (types : ty list) : ty list = + (opt_variant_id : VariantId.id option) (types : ty list) + (cgs : const_generic list) : ty list = let ty_subst = make_type_subst def.type_params types in + let cg_subst = make_const_generic_subst def.const_generic_params cgs in let fields = type_decl_get_fields def opt_variant_id in - List.map (fun f -> ty_substitute ty_subst f.field_ty) fields + List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields -let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : +let fun_sig_substitute (tsubst : TypeVarId.id -> ty) + (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) : inst_fun_sig = - let subst = ty_substitute tsubst in + let subst = ty_substitute tsubst cgsubst in let inputs = List.map subst sg.inputs in let output = subst sg.output in let doutputs = List.map subst sg.doutputs in @@ -181,9 +190,9 @@ let is_global (e : texpression) : bool = let is_const (e : texpression) : bool = match e.e with Const _ -> true | _ -> false -let ty_as_adt (ty : ty) : type_id * ty list = +let ty_as_adt (ty : ty) : type_id * ty list * const_generic list = match ty with - | Adt (id, tys) -> (id, tys) + | Adt (id, tys, cgs) -> (id, tys, cgs) | _ -> raise (Failure "Unreachable") (** Remove the external occurrences of {!Meta} *) @@ -291,13 +300,19 @@ let opt_destruct_function_call (e : texpression) : let opt_destruct_result (ty : ty) : ty option = match ty with - | Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys) + | Adt (Assumed Result, tys, cgs) -> + assert (cgs = []); + Some (Collections.List.to_cons_nil tys) | _ -> None let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty) let opt_destruct_tuple (ty : ty) : ty list option = - match ty with Adt (Tuple, tys) -> Some tys | _ -> None + match ty with + | Adt (Tuple, tys, cgs) -> + assert (cgs = []); + Some tys + | _ -> None let mk_abs (x : typed_pattern) (e : texpression) : texpression = let ty = Arrow (x.ty, e.ty) in @@ -351,7 +366,7 @@ let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) : let mk_switch (scrut : texpression) (sb : switch_body) : texpression = (* Sanity check: the scrutinee has the proper type *) (match sb with - | If (_, _) -> assert (scrut.ty = Bool) + | If (_, _) -> assert (scrut.ty = Literal Bool) | Match branches -> List.iter (fun (b : match_branch) -> assert (b.pat.ty = scrut.ty)) @@ -368,14 +383,14 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = - if there is > one type: wrap them in a tuple *) let mk_simpl_tuple_ty (tys : ty list) : ty = - match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys) + match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, []) -let mk_bool_ty : ty = Bool -let mk_unit_ty : ty = Adt (Tuple, []) +let mk_bool_ty : ty = Literal Bool +let mk_unit_ty : ty = Adt (Tuple, [], []) let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = [] } in + let qualif = { id; type_args = []; const_generic_args = [] } in let e = Qualif qualif in let ty = mk_unit_ty in { e; ty } @@ -415,7 +430,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern = | [ v ] -> v | _ -> let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in - let ty = Adt (Tuple, tys) in + let ty = Adt (Tuple, tys, []) in let value = PatAdt { variant_id = None; field_values = vl } in { value; ty } @@ -426,11 +441,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression = | _ -> (* Compute the types of the fields, and the type of the tuple constructor *) let tys = List.map (fun (v : texpression) -> v.ty) vl in - let ty = Adt (Tuple, tys) in + let ty = Adt (Tuple, tys, []) in let ty = mk_arrows tys ty in (* Construct the tuple constructor qualifier *) let id = AdtCons { adt_id = Tuple; variant_id = None } in - let qualif = { id; type_args = tys } in + let qualif = { id; type_args = tys; const_generic_args = [] } in (* Put everything together *) let cons = { e = Qualif qualif; ty } in mk_apps cons vl @@ -441,36 +456,39 @@ let mk_adt_pattern (adt_ty : ty) (variant_id : VariantId.id option) { value; ty = adt_ty } let ty_as_integer (t : ty) : T.integer_type = - match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable") + match t with + | Literal (Integer int_ty) -> int_ty + | _ -> raise (Failure "Unreachable") -(* TODO: move *) -let type_decl_is_enum (def : T.type_decl) : bool = - match def.kind with T.Struct _ -> false | Enum _ -> true | Opaque -> false +let ty_as_literal (t : ty) : T.literal_type = + match t with Literal ty -> ty | _ -> raise (Failure "Unreachable") -let mk_state_ty : ty = Adt (Assumed State, []) -let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) -let mk_error_ty : ty = Adt (Assumed Error, []) -let mk_fuel_ty : ty = Adt (Assumed Fuel, []) +let mk_state_ty : ty = Adt (Assumed State, [], []) +let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], []) +let mk_error_ty : ty = Adt (Assumed Error, [], []) +let mk_fuel_ty : ty = Adt (Assumed Fuel, [], []) let mk_error (error : VariantId.id) : texpression = let ty = mk_error_ty in let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in - let qualif = { id; type_args = [] } in + let qualif = { id; type_args = []; const_generic_args = [] } in let e = Qualif qualif in { e; ty } let unwrap_result_ty (ty : ty) : ty = match ty with - | Adt (Assumed Result, [ ty ]) -> ty + | Adt (Assumed Result, [ ty ], cgs) -> + assert (cgs = []); + ty | _ -> raise (Failure "not a result type") let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in - let ty = Adt (Assumed Result, type_args) in + let ty = Adt (Assumed Result, type_args, []) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id } in - let qualif = { id; type_args } in + let qualif = { id; type_args; const_generic_args = [] } in let cons_e = Qualif qualif in let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -483,11 +501,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) : let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in - let ty = Adt (Assumed Result, type_args) in + let ty = Adt (Assumed Result, type_args, []) in let id = AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id } in - let qualif = { id; type_args } in + let qualif = { id; type_args; const_generic_args = [] } in let cons_e = Qualif qualif in let cons_ty = mk_arrow v.ty ty in let cons = { e = cons_e; ty = cons_ty } in @@ -496,7 +514,7 @@ let mk_result_return_texpression (v : texpression) : texpression = (** Create a [Fail err] pattern which captures the error *) let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern = let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in - let ty = Adt (Assumed Result, [ ty ]) in + let ty = Adt (Assumed Result, [ ty ], []) in let value = PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] } in @@ -508,7 +526,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = mk_result_fail_pattern error_pat ty let mk_result_return_pattern (v : typed_pattern) : typed_pattern = - let ty = Adt (Assumed Result, [ v.ty ]) in + let ty = Adt (Assumed Result, [ v.ty ], []) in let value = PatAdt { variant_id = Some result_return_id; field_values = [ v ] } in @@ -543,11 +561,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option let fields_values = List.map (fun e -> Option.get e) fields in (* Retrieve the type id and the type args from the pat type (simpler this way *) - let adt_id, type_args = ty_as_adt pat.ty in + let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in (* Create the constructor *) let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in - let qualif = { id = qualif_id; type_args } in + let qualif = { id = qualif_id; type_args; const_generic_args } in let cons_e = Qualif qualif in let field_tys = List.map (fun (v : texpression) -> v.ty) fields_values -- cgit v1.2.3