summaryrefslogtreecommitdiff
path: root/compiler/PureMicroPasses.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/PureMicroPasses.ml122
1 files changed, 78 insertions, 44 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index d0741b29..959ec1c8 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -563,12 +563,13 @@ let remove_meta (def : fun_decl) : fun_decl =
This micro-pass turns those into expressions which use structure syntax:
{[
- {
- f0 := x0;
- ...
- fn := xn;
- }
+ type struct = { f0 : nat; f1 : nat; f2 : nat }
+
+ Mkstruct x.f0 x.f1 y ~~> { x with f2 = y }
]}
+
+ Note however that we do not apply this transformation if the
+ structure is to be extracted as a tuple.
*)
let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let obj =
@@ -592,37 +593,44 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
} ->
(* Lookup the def *)
let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in
- (* Check that there are as many arguments as there are fields - note
- that the def should have a body (otherwise we couldn't use the
- constructor) *)
- let fields = TypesUtils.type_decl_get_fields decl None in
- if List.length fields = List.length args then
- (* Check if the definition is recursive *)
- let is_rec =
- match
- TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups
- with
- | NonRecGroup _ -> false
- | RecGroup _ -> true
- in
- (* Convert, if possible - note that for now for Lean and Coq
- we don't support the structure syntax on recursive structures *)
- if
- (!Config.backend <> Lean && !Config.backend <> Coq)
- || not is_rec
- then
- let struct_id = TAdtId adt_id in
- let init = None in
- let updates =
- FieldId.mapi
- (fun fid fe -> (fid, self#visit_texpression env fe))
- args
+ (* Check if the def will be extracted as a tuple *)
+ if
+ TypesUtils.type_decl_from_decl_id_is_tuple_struct
+ ctx.type_ctx.type_infos adt_id
+ then ignore ()
+ else
+ (* Check that there are as many arguments as there are fields - note
+ that the def should have a body (otherwise we couldn't use the
+ constructor) *)
+ let fields = TypesUtils.type_decl_get_fields decl None in
+ if List.length fields = List.length args then
+ (* Check if the definition is recursive *)
+ let is_rec =
+ match
+ TypeDeclId.Map.find adt_id
+ ctx.type_ctx.type_decls_groups
+ with
+ | NonRecGroup _ -> false
+ | RecGroup _ -> true
in
- let ne = { struct_id; init; updates } in
- let nty = e.ty in
- { e = StructUpdate ne; ty = nty }
+ (* Convert, if possible - note that for now for Lean and Coq
+ we don't support the structure syntax on recursive structures *)
+ if
+ (!Config.backend <> Lean && !Config.backend <> Coq)
+ || not is_rec
+ then
+ let struct_id = TAdtId adt_id in
+ let init = None in
+ let updates =
+ FieldId.mapi
+ (fun fid fe -> (fid, self#visit_texpression env fe))
+ args
+ in
+ let ne = { struct_id; init; updates } in
+ let nty = e.ty in
+ { e = StructUpdate ne; ty = nty }
+ else ignore ()
else ignore ()
- else ignore ()
| _ -> ignore ())
| _ -> super#visit_texpression env e
end
@@ -659,8 +667,8 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
leave the let-bindings where they are, and eliminated them in a subsequent
pass (if they are useless).
*)
-let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
- (def : fun_decl) : fun_decl =
+let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool)
+ (inline_pure : bool) (def : fun_decl) : fun_decl =
let obj =
object (self)
inherit [_] map_expression as super
@@ -669,9 +677,12 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
the substitution map while doing so *)
method! visit_Let (env : texpression VarId.Map.t) monadic lv re e =
(* In order to filter, we need to check first that:
- * - the let-binding is not monadic
- * - the left-value is a variable
- *)
+ - the let-binding is not monadic
+ - the left-value is a variable
+
+ We also inline if the binding decomposes a structure that is to be
+ extracted as a tuple, and the right value is a variable.
+ *)
match (monadic, lv.value) with
| false, PatVar (lv_var, _) ->
(* We can filter if: *)
@@ -717,6 +728,31 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
let e = self#visit_texpression env e in
(* Reconstruct the [let], only if the binding is not filtered *)
if filter then e.e else Let (monadic, lv, re, e)
+ | ( false,
+ PatAdt
+ {
+ variant_id = None;
+ field_values = [ { value = PatVar (lv_var, _); ty = _ } ];
+ } ) ->
+ (* Second case: we deconstruct a structure with one field that we will
+ extract as tuple. *)
+ let adt_id, _ = PureUtils.ty_as_adt re.ty in
+ (* Update the rhs (we may perform substitutions inside, and it is
+ * better to do them *before* we inline it *)
+ let re = self#visit_texpression env re in
+ if
+ PureUtils.is_var re
+ && type_decl_from_type_id_is_tuple_struct ctx.type_ctx.type_infos
+ adt_id
+ then
+ (* Update the substitution environment *)
+ let env = VarId.Map.add lv_var.id re env in
+ (* Update the next expression *)
+ let e = self#visit_texpression env e in
+ (* We filter the [let], and thus do not reconstruct it *)
+ e.e
+ else (* Nothing to do *)
+ super#visit_Let env monadic lv re e
| _ -> super#visit_Let env monadic lv re e
(** Substitute the variables *)
@@ -1069,12 +1105,10 @@ let simplify_let_then_return _ctx def =
(** Simplify the aggregated ADTs.
Ex.:
{[
- type struct = { f0 : nat; f1 : nat }
+ type struct = { f0 : nat; f1 : nat; f2 : nat }
- Mkstruct x.f0 x.f1 ~~> x
+ Mkstruct x.f0 x.f1 x.f2 ~~> x
]}
-
- TODO: introduce a notation for [{ x with field = ... }], and use it.
*)
let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let expr_visitor =
@@ -1786,7 +1820,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let inline_named_vars = true in
let inline_pure = true in
let def =
- inline_useless_var_reassignments inline_named_vars inline_pure def
+ inline_useless_var_reassignments ctx inline_named_vars inline_pure def
in
log#ldebug
(lazy