summaryrefslogtreecommitdiff
path: root/compiler/PureMicroPasses.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/PureMicroPasses.ml')
-rw-r--r--compiler/PureMicroPasses.ml53
1 files changed, 41 insertions, 12 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 74f3c576..b6025df4 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -585,6 +585,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
{
id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
type_args = _;
+ const_generic_args = _;
} ->
(* Lookup the def *)
let decl =
@@ -610,7 +611,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(!Config.backend <> Lean && !Config.backend <> Coq)
|| not is_rec
then
- let struct_id = adt_id in
+ let struct_id = AdtId adt_id in
let init = None in
let updates =
FieldId.mapi
@@ -1086,6 +1087,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
{
id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
type_args;
+ const_generic_args;
} ->
(* This is a struct *)
(* Retrieve the definiton, to find how many fields there are *)
@@ -1106,7 +1108,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
* [x.field] for some variable [x], and where the projection
* is for the proper ADT *)
let to_var_proj (i : int) (arg : texpression) :
- (ty list * var_id) option =
+ (ty list * const_generic list * var_id) option =
match arg.e with
| App (proj, x) -> (
match (proj.e, x.e) with
@@ -1115,13 +1117,15 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
id =
Proj { adt_id = AdtId proj_adt_id; field_id };
type_args = proj_type_args;
+ const_generic_args = proj_const_generic_args;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
if
proj_adt_id = adt_id
&& FieldId.to_int field_id = i
- then Some (proj_type_args, v)
+ then
+ Some (proj_type_args, proj_const_generic_args, v)
else None
| _ -> None)
| _ -> None
@@ -1132,12 +1136,15 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
if List.length args = num_fields then
(* Check that this is the same variable we project from -
* note that we checked above that there is at least one field *)
- let (_, x), end_args = Collections.List.pop args in
- if List.for_all (fun (_, y) -> y = x) end_args then (
+ let (_, _, x), end_args = Collections.List.pop args in
+ if List.for_all (fun (_, _, y) -> y = x) end_args then (
(* We can substitute *)
(* Sanity check: all types correct *)
assert (
- List.for_all (fun (tys, _) -> tys = type_args) args);
+ List.for_all
+ (fun (tys, cgs, _) ->
+ tys = type_args && cgs = const_generic_args)
+ args);
{ e with e = Var x })
else super#visit_texpression env e
else super#visit_texpression env e
@@ -1156,12 +1163,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
{
id = Proj { adt_id = AdtId proj_adt_id; field_id };
type_args = _;
+ const_generic_args = _;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
if
- proj_adt_id = struct_id && field_id = fid
- && x.ty = adt_ty
+ AdtId proj_adt_id = struct_id
+ && field_id = fid && x.ty = adt_ty
then Some v
else None
| _ -> None)
@@ -1354,6 +1362,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_sig =
{
type_params = fun_sig.type_params;
+ const_generic_params = fun_sig.const_generic_params;
inputs = inputs_tys;
output;
doutputs;
@@ -1554,8 +1563,12 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
| A.BoxFree, _ ->
assert (args = []);
mk_unit_rvalue
- | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen
- | A.VecIndex | A.VecIndexMut ),
+ | ( ( A.Replace | VecNew | VecPush | VecInsert | VecLen
+ | VecIndex | VecIndexMut | ArraySubsliceShared
+ | ArraySubsliceMut | SliceIndexShared | SliceIndexMut
+ | SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared
+ | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
+ | SliceLen ),
_ ) ->
super#visit_texpression env e)
| _ -> super#visit_texpression env e)
@@ -2130,7 +2143,14 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
let num_filtered =
List.length (List.filter (fun b -> not b) used_info)
in
- let { type_params; inputs; output; doutputs; info } =
+ let {
+ type_params;
+ const_generic_params;
+ inputs;
+ output;
+ doutputs;
+ info;
+ } =
decl.signature
in
let {
@@ -2158,7 +2178,16 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
effect_info;
}
in
- let signature = { type_params; inputs; output; doutputs; info } in
+ let signature =
+ {
+ type_params;
+ const_generic_params;
+ inputs;
+ output;
+ doutputs;
+ info;
+ }
+ in
{ decl with signature }
in