summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/PureMicroPasses.ml45
1 files changed, 43 insertions, 2 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 7e6ca822..7d01a622 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1078,6 +1078,8 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
method! visit_texpression env e =
match e.e with
| App _ -> (
+ (* TODO: we should remove this case, which dates from before the
+ introduction of [StructUpdate] *)
let app, args = destruct_apps e in
match app.e with
| Qualif
@@ -1141,6 +1143,43 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
else super#visit_texpression env e
else super#visit_texpression env e
| _ -> super#visit_texpression env e)
+ | StructUpdate { struct_id; init = None; updates } ->
+ let adt_ty = e.ty in
+ (* Attempt to convert all the field updates to projections
+ of fields from an ADT with the same type *)
+ let to_var_proj ((fid, arg) : FieldId.id * texpression) :
+ var_id option =
+ match arg.e with
+ | App (proj, x) -> (
+ match (proj.e, x.e) with
+ | ( Qualif
+ {
+ id = Proj { adt_id = AdtId proj_adt_id; field_id };
+ type_args = _;
+ },
+ Var v ) ->
+ (* We check that this is the proper ADT, and the proper field *)
+ if
+ proj_adt_id = struct_id && field_id = fid
+ && x.ty = adt_ty
+ then Some v
+ else None
+ | _ -> None)
+ | _ -> None
+ in
+ let var_projs = List.map to_var_proj updates in
+ let filt_var_projs = List.filter_map (fun x -> x) var_projs in
+ if filt_var_projs = [] then super#visit_texpression env e
+ else
+ (* If all the projections are from the same variable [x], we
+ simply replace the whole expression with [x] *)
+ let x = List.hd filt_var_projs in
+ if
+ List.length filt_var_projs = List.length updates
+ && List.for_all (fun y -> y = x) filt_var_projs
+ then { e with e = Var x }
+ else (* TODO *)
+ super#visit_texpression env e
| _ -> super#visit_texpression env e
end
in
@@ -1750,9 +1789,11 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
Ex.:
{[
- (* type struct = { f0 : nat; f1 : nat } *)
+ (* type struct = { f0 : nat; f1 : nat; f2 : nat } *)
- Mkstruct x.f0 x.f1 ~~> x
+ Mkstruct x.f0 x.f1 x.f2 ~~> x
+ { f0 := x.f0; f1 := x.f1; f2 := x.f2 } ~~> x
+ { f0 := x.f0; f1 := x.f1; f2 := v } ~~> { x with f2 = v }
]}
*)
let def = simplify_aggregates ctx def in