diff options
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/PureMicroPasses.ml | 45 |
1 files changed, 43 insertions, 2 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 7e6ca822..7d01a622 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1078,6 +1078,8 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = method! visit_texpression env e = match e.e with | App _ -> ( + (* TODO: we should remove this case, which dates from before the + introduction of [StructUpdate] *) let app, args = destruct_apps e in match app.e with | Qualif @@ -1141,6 +1143,43 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = else super#visit_texpression env e else super#visit_texpression env e | _ -> super#visit_texpression env e) + | StructUpdate { struct_id; init = None; updates } -> + let adt_ty = e.ty in + (* Attempt to convert all the field updates to projections + of fields from an ADT with the same type *) + let to_var_proj ((fid, arg) : FieldId.id * texpression) : + var_id option = + match arg.e with + | App (proj, x) -> ( + match (proj.e, x.e) with + | ( Qualif + { + id = Proj { adt_id = AdtId proj_adt_id; field_id }; + type_args = _; + }, + Var v ) -> + (* We check that this is the proper ADT, and the proper field *) + if + proj_adt_id = struct_id && field_id = fid + && x.ty = adt_ty + then Some v + else None + | _ -> None) + | _ -> None + in + let var_projs = List.map to_var_proj updates in + let filt_var_projs = List.filter_map (fun x -> x) var_projs in + if filt_var_projs = [] then super#visit_texpression env e + else + (* If all the projections are from the same variable [x], we + simply replace the whole expression with [x] *) + let x = List.hd filt_var_projs in + if + List.length filt_var_projs = List.length updates + && List.for_all (fun y -> y = x) filt_var_projs + then { e with e = Var x } + else (* TODO *) + super#visit_texpression env e | _ -> super#visit_texpression env e end in @@ -1750,9 +1789,11 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = Ex.: {[ - (* type struct = { f0 : nat; f1 : nat } *) + (* type struct = { f0 : nat; f1 : nat; f2 : nat } *) - Mkstruct x.f0 x.f1 ~~> x + Mkstruct x.f0 x.f1 x.f2 ~~> x + { f0 := x.f0; f1 := x.f1; f2 := x.f2 } ~~> x + { f0 := x.f0; f1 := x.f1; f2 := v } ~~> { x with f2 = v } ]} *) let def = simplify_aggregates ctx def in |