From f4739fba4be95818ca01776837c8d610e443a45b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 17 Jun 2024 07:14:52 +0200 Subject: Automatically add a @[reducible] attribute to some generated functions --- compiler/Extract.ml | 4 ++++ compiler/Pure.ml | 7 ++++++ compiler/PureMicroPasses.ml | 52 ++++++++++++++++++++++++++++++++++++++++++++- compiler/SymbolicToPure.ml | 6 ++++-- 4 files changed, 66 insertions(+), 3 deletions(-) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 4acf3f99..b1adb936 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1448,6 +1448,10 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* Open two boxes for the definition, so that whenever possible it gets printed on * one line and indents are correct *) F.pp_open_hvbox fmt 0; + (* Print the attributes *) + if def.backend_attributes.reducible && backend () = Lean then ( + F.pp_print_string fmt "@[reducible]"; + F.pp_print_space fmt ()); F.pp_open_vbox fmt ctx.indent_incr; (* For HOL4: we may need to put parentheses around the definition *) let parenthesize = backend () = HOL4 && decl_is_not_last_from_group kind in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index d07b8cfa..f7445575 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -1077,11 +1077,18 @@ type fun_body = { type item_kind = A.item_kind [@@deriving show] +(** Attributes to add to the generated code *) +type backend_attributes = { + reducible : bool; (** Lean "reducible" attribute *) +} +[@@deriving show] + type fun_decl = { def_id : FunDeclId.id; is_local : bool; span : span; kind : item_kind; + backend_attributes : backend_attributes; num_loops : int; (** The number of loops in the parent forward function (basically the number of loops appearing in the original Rust functions, unless some loops are diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index b0cba250..8b95f729 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1498,6 +1498,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : is_local = def.is_local; span = loop.span; kind = def.kind; + backend_attributes = def.backend_attributes; num_loops; loop_id = Some loop.loop_id; llbc_name = def.llbc_name; @@ -2277,6 +2278,52 @@ let filter_loop_inputs (ctx : trans_ctx) (transl : pure_fun_translation list) : (* Return *) transl +(** Update the [reducible] attribute. + + For now we mark a function as reducible when its body is only a call to a loop + function. This situation often happens for simple functions whose body contains + a loop: we introduce an intermediate function for the loop body, and the + translation of the function itself simply calls the loop body. By marking + the function as reducible, we allow tactics like [simp] or [progress] to + see through the definition. + *) +let compute_reducible (_ctx : trans_ctx) (transl : pure_fun_translation list) : + pure_fun_translation list = + let update_one (trans : pure_fun_translation) : pure_fun_translation = + match trans.f.body with + | None -> trans + | Some body -> ( + (* Check if the body is exactly a call to a loop function. + Note that we check that the arguments are exactly the input + variables - otherwise we may not want the call to be reducible; + for instance when using the [progress] tactic we might want to + use a more specialized specification theorem. *) + let app, args = destruct_apps body.body in + match app.e with + | Qualif + { + id = FunOrOp (Fun (FromLlbc (FunId fid, Some _lp_id))); + generics = _; + } + when fid = FRegular trans.f.def_id -> + if + List.length body.inputs = List.length args + && List.for_all + (fun ((var, arg) : var * texpression) -> + match arg.e with + | Var var_id -> var_id = var.id + | _ -> false) + (List.combine body.inputs args) + then + let f = + { trans.f with backend_attributes = { reducible = true } } + in + { trans with f } + else trans + | _ -> trans) + in + List.map update_one transl + (** Apply all the micro-passes to a function. As loops are initially directly integrated into the function definition, @@ -2337,4 +2384,7 @@ let apply_passes_to_pure_fun_translations (ctx : trans_ctx) (* Filter the useless inputs in the loop functions (loops are initially parameterized by *all* the symbolic values in the context, because they may access any of them). *) - filter_loop_inputs ctx transl + let transl = filter_loop_inputs ctx transl in + + (* Update the "reducible" attribute *) + compute_reducible ctx transl diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 87f1128d..ad61ddd1 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1744,11 +1744,11 @@ and aloan_content_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx) (* Ignore *) None -and aborrow_content_to_consumed (_ctx : bs_ctx) (bc : V.aborrow_content) : +and aborrow_content_to_consumed (ctx : bs_ctx) (bc : V.aborrow_content) : texpression option = match bc with | V.AMutBorrow (_, _, _) | ASharedBorrow (_, _) | AIgnoredMutBorrow (_, _) -> - craise __FILE__ __LINE__ _ctx.span "Unreachable" + craise __FILE__ __LINE__ ctx.span "Unreachable" | AEndedMutBorrow (_, _) -> (* We collect consumed values: ignore *) None @@ -3894,12 +3894,14 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let loop_id = None in (* Assemble the declaration *) + let backend_attributes = { reducible = false } in let def : fun_decl = { def_id; is_local = def.is_local; span = def.item_meta.span; kind = def.kind; + backend_attributes; num_loops; loop_id; llbc_name; -- cgit v1.2.3