summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-05-15 22:34:31 +0200
committerSon Ho2022-05-15 22:34:31 +0200
commitffc93a3f4d3b29e3a6805f9882f20dd22d184939 (patch)
tree79013f1f1cb001409a1fffe526c2a27bd0ebc28b
parenta25d820b6eb02f573ad2c274a35e3496a9dacd40 (diff)
Add a pass to cleanup the deconstructed ADTs and fix a small issue
-rw-r--r--fstar/Primitives.fst2
-rw-r--r--src/ExtractToFStar.ml6
-rw-r--r--src/PrePasses.ml1
-rw-r--r--src/PureMicroPasses.ml117
4 files changed, 113 insertions, 13 deletions
diff --git a/fstar/Primitives.fst b/fstar/Primitives.fst
index f73c8c09..fe351f3a 100644
--- a/fstar/Primitives.fst
+++ b/fstar/Primitives.fst
@@ -146,7 +146,7 @@ let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala
mk_scalar ty (x * y)
(** Cast an integer from a [src_ty] to a [tgt_ty] *)
-let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) =
+let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) =
mk_scalar tgt_ty x
/// The scalar types
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 84e447a8..8b37b96a 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -156,13 +156,17 @@ let fstar_extract_unop (extract_expr : bool -> texpression -> unit)
F.pp_print_space fmt ();
extract_expr true arg;
if inside then F.pp_print_string fmt ")"
- | Cast (_src, tgt) ->
+ | Cast (src, tgt) ->
(* The source type is an implicit parameter *)
if inside then F.pp_print_string fmt "(";
F.pp_print_string fmt "scalar_cast";
F.pp_print_space fmt ();
F.pp_print_string fmt
(StringUtils.capitalize_first_letter
+ (PrintPure.integer_type_to_string src));
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt
+ (StringUtils.capitalize_first_letter
(PrintPure.integer_type_to_string tgt));
F.pp_print_space fmt ();
extract_expr true arg;
diff --git a/src/PrePasses.ml b/src/PrePasses.ml
index dda3c867..c9d496ea 100644
--- a/src/PrePasses.ml
+++ b/src/PrePasses.ml
@@ -23,6 +23,7 @@ let log = L.pre_passes_log
*x = move ...;
```
+ TODO: this is not necessary anymore
*)
let filter_drop_assigns (f : A.fun_decl) : A.fun_decl =
(* The visitor *)
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 0c371420..826283ae 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -85,9 +85,7 @@ let get_body_min_var_counter (body : fun_body) : VarId.generator =
let obj =
object
inherit [_] reduce_expression
-
method zero _ = min_input_id
-
method plus id0 id1 _ = VarId.max (id0 ()) (id1 ())
(* Get the maximum *)
@@ -294,7 +292,6 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let obj =
object
inherit [_] map_typed_pattern
-
method! visit_PatVar _ v mp = PatVar (update_var ctx v mp, mp)
end
in
@@ -356,9 +353,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let obj =
object (self)
inherit [_] reduce_typed_pattern
-
method zero _ = empty_ctx
-
method plus ctx0 ctx1 _ = merge_ctxs (ctx0 ()) (ctx1 ())
method! visit_PatVar _ v mp () =
@@ -742,9 +737,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
let visitor =
object (self)
inherit [_] reduce_expression
-
method zero _ = false
-
method plus b0 b1 _ = b0 () && b1 ()
method! visit_texpression env e =
@@ -806,9 +799,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
let lv_visitor =
object
inherit [_] mapreduce_typed_pattern
-
method zero _ = true
-
method plus b0 b1 _ = b0 () && b1 ()
method! visit_PatVar env v mp =
@@ -830,9 +821,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
let expr_visitor =
object (self)
inherit [_] mapreduce_expression as super
-
method zero _ = VarId.Set.empty
-
method plus s0 s1 _ = VarId.Set.union (s0 ()) (s1 ())
method! visit_Var _ vid = (Var vid, fun _ -> VarId.Set.singleton vid)
@@ -910,6 +899,100 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
let body = { body with body = body_exp; inputs_lvs } in
{ def with body = Some body }
+(** Simplify the aggregated ADTs.
+ Ex.:
+ ```
+ type struct = { f0 : nat; f1 : nat }
+
+ Mkstruct x.f0 x.f1 ~~> x
+ ```
+
+ TODO: introduce a notation for { x with field = ... }, and use it.
+ *)
+let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
+ let expr_visitor =
+ object
+ inherit [_] map_expression as super
+
+ (* Look for a type constructor applied to arguments *)
+ method! visit_texpression env e =
+ match e.e with
+ | App _ -> (
+ let app, args = destruct_apps e in
+ match app.e with
+ | Qualif
+ {
+ id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
+ type_args;
+ } ->
+ (* This is a struct *)
+ (* Retrieve the definiton, to find how many fields there are *)
+ let adt_decl =
+ TypeDeclId.Map.find adt_id ctx.type_context.type_decls
+ in
+ let fields =
+ match adt_decl.kind with
+ | Enum _ | Opaque -> raise (Failure "Unreachable")
+ | Struct fields -> fields
+ in
+ let num_fields = List.length fields in
+ (* In order to simplify, there must be as many arguments as
+ * there are fields *)
+ assert (num_fields > 0);
+ if num_fields = List.length args then
+ (* We now need to check that all the arguments are of the form:
+ * `x.field` for some variable `x`, and where the projection
+ * is for the proper ADT *)
+ let to_var_proj (i : int) (arg : texpression) :
+ (ty list * var_id) option =
+ match arg.e with
+ | App (proj, x) -> (
+ match (proj.e, x.e) with
+ | ( Qualif
+ {
+ id =
+ Proj { adt_id = AdtId proj_adt_id; field_id };
+ type_args = proj_type_args;
+ },
+ Var v ) ->
+ (* We check that this is the proper ADT, and the proper field *)
+ if
+ proj_adt_id = adt_id
+ && FieldId.to_int field_id = i
+ then Some (proj_type_args, v)
+ else None
+ | _ -> None)
+ | _ -> None
+ in
+ let args = List.mapi to_var_proj args in
+ let args = List.filter_map (fun x -> x) args in
+ (* Check that all the arguments are of the expected form *)
+ if List.length args = num_fields then
+ (* Check that this is the same variable we project from -
+ * note that we checked above that there is at least one field *)
+ let (_, x), end_args = Collections.List.pop args in
+ if List.for_all (fun (_, y) -> y = x) end_args then (
+ (* We can substitute *)
+ (* Sanity check: all types correct *)
+ assert (
+ List.for_all (fun (tys, _) -> tys = type_args) args);
+ { e with e = Var x })
+ else super#visit_texpression env e
+ else super#visit_texpression env e
+ else super#visit_texpression env e
+ | _ -> super#visit_texpression env e)
+ | _ -> super#visit_texpression env e
+ end
+ in
+ match def.body with
+ | None -> def
+ | Some body ->
+ (* Visit the body *)
+ let body_exp = expr_visitor#visit_texpression () body.body in
+ (* Return *)
+ let body = { body with body = body_exp } in
+ { def with body = Some body }
+
(** Return `None` if the function is a backward function with no outputs (so
that we eliminate the definition which is useless).
@@ -1222,6 +1305,18 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) :
log#ldebug
(lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+ (* Simplify the aggregated ADTs.
+ * Ex.:
+ * ```
+ * type struct = { f0 : nat; f1 : nat }
+ *
+ * Mkstruct x.f0 x.f1 ~~> x
+ * ```
+ *)
+ let def = simplify_aggregates ctx def in
+ log#ldebug
+ (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
(* Decompose the monadic let-bindings - F* specific
* TODO: remove? *)
let def =