From ffc93a3f4d3b29e3a6805f9882f20dd22d184939 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sun, 15 May 2022 22:34:31 +0200 Subject: Add a pass to cleanup the deconstructed ADTs and fix a small issue --- fstar/Primitives.fst | 2 +- src/ExtractToFStar.ml | 6 ++- src/PrePasses.ml | 1 + src/PureMicroPasses.ml | 117 ++++++++++++++++++++++++++++++++++++++++++++----- 4 files changed, 113 insertions(+), 13 deletions(-) diff --git a/fstar/Primitives.fst b/fstar/Primitives.fst index f73c8c09..fe351f3a 100644 --- a/fstar/Primitives.fst +++ b/fstar/Primitives.fst @@ -146,7 +146,7 @@ let scalar_mul (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scala mk_scalar ty (x * y) (** Cast an integer from a [src_ty] to a [tgt_ty] *) -let scalar_cast (#src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) = +let scalar_cast (src_ty : scalar_ty) (tgt_ty : scalar_ty) (x : scalar src_ty) : result (scalar tgt_ty) = mk_scalar tgt_ty x /// The scalar types diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 84e447a8..8b37b96a 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -156,11 +156,15 @@ let fstar_extract_unop (extract_expr : bool -> texpression -> unit) F.pp_print_space fmt (); extract_expr true arg; if inside then F.pp_print_string fmt ")" - | Cast (_src, tgt) -> + | Cast (src, tgt) -> (* The source type is an implicit parameter *) if inside then F.pp_print_string fmt "("; F.pp_print_string fmt "scalar_cast"; F.pp_print_space fmt (); + F.pp_print_string fmt + (StringUtils.capitalize_first_letter + (PrintPure.integer_type_to_string src)); + F.pp_print_space fmt (); F.pp_print_string fmt (StringUtils.capitalize_first_letter (PrintPure.integer_type_to_string tgt)); diff --git a/src/PrePasses.ml b/src/PrePasses.ml index dda3c867..c9d496ea 100644 --- a/src/PrePasses.ml +++ b/src/PrePasses.ml @@ -23,6 +23,7 @@ let log = L.pre_passes_log *x = move ...; ``` + TODO: this is not necessary anymore *) let filter_drop_assigns (f : A.fun_decl) : A.fun_decl = (* The visitor *) diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 0c371420..826283ae 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -85,9 +85,7 @@ let get_body_min_var_counter (body : fun_body) : VarId.generator = let obj = object inherit [_] reduce_expression - method zero _ = min_input_id - method plus id0 id1 _ = VarId.max (id0 ()) (id1 ()) (* Get the maximum *) @@ -294,7 +292,6 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let obj = object inherit [_] map_typed_pattern - method! visit_PatVar _ v mp = PatVar (update_var ctx v mp, mp) end in @@ -356,9 +353,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let obj = object (self) inherit [_] reduce_typed_pattern - method zero _ = empty_ctx - method plus ctx0 ctx1 _ = merge_ctxs (ctx0 ()) (ctx1 ()) method! visit_PatVar _ v mp () = @@ -742,9 +737,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) let visitor = object (self) inherit [_] reduce_expression - method zero _ = false - method plus b0 b1 _ = b0 () && b1 () method! visit_texpression env e = @@ -806,9 +799,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) let lv_visitor = object inherit [_] mapreduce_typed_pattern - method zero _ = true - method plus b0 b1 _ = b0 () && b1 () method! visit_PatVar env v mp = @@ -830,9 +821,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) let expr_visitor = object (self) inherit [_] mapreduce_expression as super - method zero _ = VarId.Set.empty - method plus s0 s1 _ = VarId.Set.union (s0 ()) (s1 ()) method! visit_Var _ vid = (Var vid, fun _ -> VarId.Set.singleton vid) @@ -910,6 +899,100 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) let body = { body with body = body_exp; inputs_lvs } in { def with body = Some body } +(** Simplify the aggregated ADTs. + Ex.: + ``` + type struct = { f0 : nat; f1 : nat } + + Mkstruct x.f0 x.f1 ~~> x + ``` + + TODO: introduce a notation for { x with field = ... }, and use it. + *) +let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = + let expr_visitor = + object + inherit [_] map_expression as super + + (* Look for a type constructor applied to arguments *) + method! visit_texpression env e = + match e.e with + | App _ -> ( + let app, args = destruct_apps e in + match app.e with + | Qualif + { + id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; + type_args; + } -> + (* This is a struct *) + (* Retrieve the definiton, to find how many fields there are *) + let adt_decl = + TypeDeclId.Map.find adt_id ctx.type_context.type_decls + in + let fields = + match adt_decl.kind with + | Enum _ | Opaque -> raise (Failure "Unreachable") + | Struct fields -> fields + in + let num_fields = List.length fields in + (* In order to simplify, there must be as many arguments as + * there are fields *) + assert (num_fields > 0); + if num_fields = List.length args then + (* We now need to check that all the arguments are of the form: + * `x.field` for some variable `x`, and where the projection + * is for the proper ADT *) + let to_var_proj (i : int) (arg : texpression) : + (ty list * var_id) option = + match arg.e with + | App (proj, x) -> ( + match (proj.e, x.e) with + | ( Qualif + { + id = + Proj { adt_id = AdtId proj_adt_id; field_id }; + type_args = proj_type_args; + }, + Var v ) -> + (* We check that this is the proper ADT, and the proper field *) + if + proj_adt_id = adt_id + && FieldId.to_int field_id = i + then Some (proj_type_args, v) + else None + | _ -> None) + | _ -> None + in + let args = List.mapi to_var_proj args in + let args = List.filter_map (fun x -> x) args in + (* Check that all the arguments are of the expected form *) + if List.length args = num_fields then + (* Check that this is the same variable we project from - + * note that we checked above that there is at least one field *) + let (_, x), end_args = Collections.List.pop args in + if List.for_all (fun (_, y) -> y = x) end_args then ( + (* We can substitute *) + (* Sanity check: all types correct *) + assert ( + List.for_all (fun (tys, _) -> tys = type_args) args); + { e with e = Var x }) + else super#visit_texpression env e + else super#visit_texpression env e + else super#visit_texpression env e + | _ -> super#visit_texpression env e) + | _ -> super#visit_texpression env e + end + in + match def.body with + | None -> def + | Some body -> + (* Visit the body *) + let body_exp = expr_visitor#visit_texpression () body.body in + (* Return *) + let body = { body with body = body_exp } in + { def with body = Some body } + (** Return `None` if the function is a backward function with no outputs (so that we eliminate the definition which is useless). @@ -1222,6 +1305,18 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) : log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Simplify the aggregated ADTs. + * Ex.: + * ``` + * type struct = { f0 : nat; f1 : nat } + * + * Mkstruct x.f0 x.f1 ~~> x + * ``` + *) + let def = simplify_aggregates ctx def in + log#ldebug + (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Decompose the monadic let-bindings - F* specific * TODO: remove? *) let def = -- cgit v1.2.3