summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2024-06-17 07:14:52 +0200
committerSon Ho2024-06-17 07:14:52 +0200
commitf4739fba4be95818ca01776837c8d610e443a45b (patch)
tree104aca04e404e3bc3f82cc5fd9a4f21f59789b53
parent68e623b037a07c986f1a84e21196b9eee29a0d8e (diff)
Automatically add a @[reducible] attribute to some generated functions
-rw-r--r--compiler/Extract.ml4
-rw-r--r--compiler/Pure.ml7
-rw-r--r--compiler/PureMicroPasses.ml52
-rw-r--r--compiler/SymbolicToPure.ml6
4 files changed, 66 insertions, 3 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 4acf3f99..b1adb936 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1448,6 +1448,10 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Open two boxes for the definition, so that whenever possible it gets printed on
* one line and indents are correct *)
F.pp_open_hvbox fmt 0;
+ (* Print the attributes *)
+ if def.backend_attributes.reducible && backend () = Lean then (
+ F.pp_print_string fmt "@[reducible]";
+ F.pp_print_space fmt ());
F.pp_open_vbox fmt ctx.indent_incr;
(* For HOL4: we may need to put parentheses around the definition *)
let parenthesize = backend () = HOL4 && decl_is_not_last_from_group kind in
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index d07b8cfa..f7445575 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -1077,11 +1077,18 @@ type fun_body = {
type item_kind = A.item_kind [@@deriving show]
+(** Attributes to add to the generated code *)
+type backend_attributes = {
+ reducible : bool; (** Lean "reducible" attribute *)
+}
+[@@deriving show]
+
type fun_decl = {
def_id : FunDeclId.id;
is_local : bool;
span : span;
kind : item_kind;
+ backend_attributes : backend_attributes;
num_loops : int;
(** The number of loops in the parent forward function (basically the number
of loops appearing in the original Rust functions, unless some loops are
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index b0cba250..8b95f729 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1498,6 +1498,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
is_local = def.is_local;
span = loop.span;
kind = def.kind;
+ backend_attributes = def.backend_attributes;
num_loops;
loop_id = Some loop.loop_id;
llbc_name = def.llbc_name;
@@ -2277,6 +2278,52 @@ let filter_loop_inputs (ctx : trans_ctx) (transl : pure_fun_translation list) :
(* Return *)
transl
+(** Update the [reducible] attribute.
+
+ For now we mark a function as reducible when its body is only a call to a loop
+ function. This situation often happens for simple functions whose body contains
+ a loop: we introduce an intermediate function for the loop body, and the
+ translation of the function itself simply calls the loop body. By marking
+ the function as reducible, we allow tactics like [simp] or [progress] to
+ see through the definition.
+ *)
+let compute_reducible (_ctx : trans_ctx) (transl : pure_fun_translation list) :
+ pure_fun_translation list =
+ let update_one (trans : pure_fun_translation) : pure_fun_translation =
+ match trans.f.body with
+ | None -> trans
+ | Some body -> (
+ (* Check if the body is exactly a call to a loop function.
+ Note that we check that the arguments are exactly the input
+ variables - otherwise we may not want the call to be reducible;
+ for instance when using the [progress] tactic we might want to
+ use a more specialized specification theorem. *)
+ let app, args = destruct_apps body.body in
+ match app.e with
+ | Qualif
+ {
+ id = FunOrOp (Fun (FromLlbc (FunId fid, Some _lp_id)));
+ generics = _;
+ }
+ when fid = FRegular trans.f.def_id ->
+ if
+ List.length body.inputs = List.length args
+ && List.for_all
+ (fun ((var, arg) : var * texpression) ->
+ match arg.e with
+ | Var var_id -> var_id = var.id
+ | _ -> false)
+ (List.combine body.inputs args)
+ then
+ let f =
+ { trans.f with backend_attributes = { reducible = true } }
+ in
+ { trans with f }
+ else trans
+ | _ -> trans)
+ in
+ List.map update_one transl
+
(** Apply all the micro-passes to a function.
As loops are initially directly integrated into the function definition,
@@ -2337,4 +2384,7 @@ let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
(* Filter the useless inputs in the loop functions (loops are initially
parameterized by *all* the symbolic values in the context, because
they may access any of them). *)
- filter_loop_inputs ctx transl
+ let transl = filter_loop_inputs ctx transl in
+
+ (* Update the "reducible" attribute *)
+ compute_reducible ctx transl
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 87f1128d..ad61ddd1 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -1744,11 +1744,11 @@ and aloan_content_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx)
(* Ignore *)
None
-and aborrow_content_to_consumed (_ctx : bs_ctx) (bc : V.aborrow_content) :
+and aborrow_content_to_consumed (ctx : bs_ctx) (bc : V.aborrow_content) :
texpression option =
match bc with
| V.AMutBorrow (_, _, _) | ASharedBorrow (_, _) | AIgnoredMutBorrow (_, _) ->
- craise __FILE__ __LINE__ _ctx.span "Unreachable"
+ craise __FILE__ __LINE__ ctx.span "Unreachable"
| AEndedMutBorrow (_, _) ->
(* We collect consumed values: ignore *)
None
@@ -3894,12 +3894,14 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let loop_id = None in
(* Assemble the declaration *)
+ let backend_attributes = { reducible = false } in
let def : fun_decl =
{
def_id;
is_local = def.is_local;
span = def.item_meta.span;
kind = def.kind;
+ backend_attributes;
num_loops;
loop_id;
llbc_name;