summaryrefslogtreecommitdiff
path: root/compiler/PureUtils.ml
diff options
context:
space:
mode:
authorSon Ho2023-08-02 11:03:59 +0200
committerSon Ho2023-08-02 11:03:59 +0200
commit9d27e2e27db06eaad7565b55366ca8734b364fca (patch)
tree7cb450a93c538d671486e1d9f40aa1258401a31e /compiler/PureUtils.ml
parent50af296306bfee9f0b127dde8abe5fb0ec1b0acb (diff)
Make progress proapagating the changes
Diffstat (limited to 'compiler/PureUtils.ml')
-rw-r--r--compiler/PureUtils.ml92
1 files changed, 55 insertions, 37 deletions
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 88b18e89..1c8d8921 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -62,7 +62,7 @@ let dest_arrow_ty (ty : ty) : ty * ty =
| Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty)
| _ -> raise (Failure "Unreachable")
-let compute_literal_ty (cv : literal) : ty =
+let compute_literal_type (cv : literal) : literal_type =
match cv with
| PV.Scalar sv -> Integer sv.PV.int_ty
| Bool _ -> Bool
@@ -71,7 +71,7 @@ let compute_literal_ty (cv : literal) : ty =
let var_get_id (v : var) : VarId.id = v.id
let mk_typed_pattern_from_literal (cv : literal) : typed_pattern =
- let ty = compute_literal_ty cv in
+ let ty = Literal (compute_literal_type cv) in
{ value = PatConstant cv; ty }
let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression)
@@ -90,11 +90,13 @@ let mk_mplace (var_id : E.VarId.id) (name : string option)
{ var_id; name; projection }
(** Type substitution *)
-let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty =
+let ty_substitute (tsubst : TypeVarId.id -> ty)
+ (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty =
let obj =
object
inherit [_] map_ty
method! visit_TypeVar _ var_id = tsubst var_id
+ method! visit_ConstGenericVar _ var_id = cgsubst var_id
end
in
obj#visit_ty () ty
@@ -109,6 +111,10 @@ let make_type_subst (vars : type_var list) (tys : ty list) : TypeVarId.id -> ty
in
fun id -> TypeVarId.Map.find id mp
+let make_const_generic_subst (vars : const_generic_var list)
+ (cgs : const_generic list) : ConstGenericVarId.id -> const_generic =
+ Substitute.make_const_generic_subst_from_vars vars cgs
+
(** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}.
Raises [Invalid_argument] if the arguments are incorrect.
@@ -132,14 +138,17 @@ let type_decl_get_fields (def : type_decl)
(** Instantiate the type variables for the chosen variant in an ADT definition,
and return the list of the types of its fields *)
let type_decl_get_instantiated_fields_types (def : type_decl)
- (opt_variant_id : VariantId.id option) (types : ty list) : ty list =
+ (opt_variant_id : VariantId.id option) (types : ty list)
+ (cgs : const_generic list) : ty list =
let ty_subst = make_type_subst def.type_params types in
+ let cg_subst = make_const_generic_subst def.const_generic_params cgs in
let fields = type_decl_get_fields def opt_variant_id in
- List.map (fun f -> ty_substitute ty_subst f.field_ty) fields
+ List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields
-let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
+let fun_sig_substitute (tsubst : TypeVarId.id -> ty)
+ (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) :
inst_fun_sig =
- let subst = ty_substitute tsubst in
+ let subst = ty_substitute tsubst cgsubst in
let inputs = List.map subst sg.inputs in
let output = subst sg.output in
let doutputs = List.map subst sg.doutputs in
@@ -181,9 +190,9 @@ let is_global (e : texpression) : bool =
let is_const (e : texpression) : bool =
match e.e with Const _ -> true | _ -> false
-let ty_as_adt (ty : ty) : type_id * ty list =
+let ty_as_adt (ty : ty) : type_id * ty list * const_generic list =
match ty with
- | Adt (id, tys) -> (id, tys)
+ | Adt (id, tys, cgs) -> (id, tys, cgs)
| _ -> raise (Failure "Unreachable")
(** Remove the external occurrences of {!Meta} *)
@@ -291,13 +300,19 @@ let opt_destruct_function_call (e : texpression) :
let opt_destruct_result (ty : ty) : ty option =
match ty with
- | Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys)
+ | Adt (Assumed Result, tys, cgs) ->
+ assert (cgs = []);
+ Some (Collections.List.to_cons_nil tys)
| _ -> None
let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty)
let opt_destruct_tuple (ty : ty) : ty list option =
- match ty with Adt (Tuple, tys) -> Some tys | _ -> None
+ match ty with
+ | Adt (Tuple, tys, cgs) ->
+ assert (cgs = []);
+ Some tys
+ | _ -> None
let mk_abs (x : typed_pattern) (e : texpression) : texpression =
let ty = Arrow (x.ty, e.ty) in
@@ -351,7 +366,7 @@ let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) :
let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
(* Sanity check: the scrutinee has the proper type *)
(match sb with
- | If (_, _) -> assert (scrut.ty = Bool)
+ | If (_, _) -> assert (scrut.ty = Literal Bool)
| Match branches ->
List.iter
(fun (b : match_branch) -> assert (b.pat.ty = scrut.ty))
@@ -368,14 +383,14 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
- if there is > one type: wrap them in a tuple
*)
let mk_simpl_tuple_ty (tys : ty list) : ty =
- match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys)
+ match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, [])
-let mk_bool_ty : ty = Bool
-let mk_unit_ty : ty = Adt (Tuple, [])
+let mk_bool_ty : ty = Literal Bool
+let mk_unit_ty : ty = Adt (Tuple, [], [])
let mk_unit_rvalue : texpression =
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = [] } in
+ let qualif = { id; type_args = []; const_generic_args = [] } in
let e = Qualif qualif in
let ty = mk_unit_ty in
{ e; ty }
@@ -415,7 +430,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern =
| [ v ] -> v
| _ ->
let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in
- let ty = Adt (Tuple, tys) in
+ let ty = Adt (Tuple, tys, []) in
let value = PatAdt { variant_id = None; field_values = vl } in
{ value; ty }
@@ -426,11 +441,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression =
| _ ->
(* Compute the types of the fields, and the type of the tuple constructor *)
let tys = List.map (fun (v : texpression) -> v.ty) vl in
- let ty = Adt (Tuple, tys) in
+ let ty = Adt (Tuple, tys, []) in
let ty = mk_arrows tys ty in
(* Construct the tuple constructor qualifier *)
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = tys } in
+ let qualif = { id; type_args = tys; const_generic_args = [] } in
(* Put everything together *)
let cons = { e = Qualif qualif; ty } in
mk_apps cons vl
@@ -441,36 +456,39 @@ let mk_adt_pattern (adt_ty : ty) (variant_id : VariantId.id option)
{ value; ty = adt_ty }
let ty_as_integer (t : ty) : T.integer_type =
- match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable")
+ match t with
+ | Literal (Integer int_ty) -> int_ty
+ | _ -> raise (Failure "Unreachable")
-(* TODO: move *)
-let type_decl_is_enum (def : T.type_decl) : bool =
- match def.kind with T.Struct _ -> false | Enum _ -> true | Opaque -> false
+let ty_as_literal (t : ty) : T.literal_type =
+ match t with Literal ty -> ty | _ -> raise (Failure "Unreachable")
-let mk_state_ty : ty = Adt (Assumed State, [])
-let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
-let mk_error_ty : ty = Adt (Assumed Error, [])
-let mk_fuel_ty : ty = Adt (Assumed Fuel, [])
+let mk_state_ty : ty = Adt (Assumed State, [], [])
+let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], [])
+let mk_error_ty : ty = Adt (Assumed Error, [], [])
+let mk_fuel_ty : ty = Adt (Assumed Fuel, [], [])
let mk_error (error : VariantId.id) : texpression =
let ty = mk_error_ty in
let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in
- let qualif = { id; type_args = [] } in
+ let qualif = { id; type_args = []; const_generic_args = [] } in
let e = Qualif qualif in
{ e; ty }
let unwrap_result_ty (ty : ty) : ty =
match ty with
- | Adt (Assumed Result, [ ty ]) -> ty
+ | Adt (Assumed Result, [ ty ], cgs) ->
+ assert (cgs = []);
+ ty
| _ -> raise (Failure "not a result type")
let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression =
let type_args = [ ty ] in
- let ty = Adt (Assumed Result, type_args) in
+ let ty = Adt (Assumed Result, type_args, []) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id }
in
- let qualif = { id; type_args } in
+ let qualif = { id; type_args; const_generic_args = [] } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow error.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -483,11 +501,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) :
let mk_result_return_texpression (v : texpression) : texpression =
let type_args = [ v.ty ] in
- let ty = Adt (Assumed Result, type_args) in
+ let ty = Adt (Assumed Result, type_args, []) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id }
in
- let qualif = { id; type_args } in
+ let qualif = { id; type_args; const_generic_args = [] } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow v.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -496,7 +514,7 @@ let mk_result_return_texpression (v : texpression) : texpression =
(** Create a [Fail err] pattern which captures the error *)
let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern =
let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in
- let ty = Adt (Assumed Result, [ ty ]) in
+ let ty = Adt (Assumed Result, [ ty ], []) in
let value =
PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] }
in
@@ -508,7 +526,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern =
mk_result_fail_pattern error_pat ty
let mk_result_return_pattern (v : typed_pattern) : typed_pattern =
- let ty = Adt (Assumed Result, [ v.ty ]) in
+ let ty = Adt (Assumed Result, [ v.ty ], []) in
let value =
PatAdt { variant_id = Some result_return_id; field_values = [ v ] }
in
@@ -543,11 +561,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option
let fields_values = List.map (fun e -> Option.get e) fields in
(* Retrieve the type id and the type args from the pat type (simpler this way *)
- let adt_id, type_args = ty_as_adt pat.ty in
+ let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in
(* Create the constructor *)
let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in
- let qualif = { id = qualif_id; type_args } in
+ let qualif = { id = qualif_id; type_args; const_generic_args } in
let cons_e = Qualif qualif in
let field_tys =
List.map (fun (v : texpression) -> v.ty) fields_values