summaryrefslogtreecommitdiff
path: root/compiler/PureMicroPasses.ml
diff options
context:
space:
mode:
authorSon Ho2022-12-17 10:27:12 +0100
committerSon HO2023-02-03 11:21:46 +0100
commit66638a2a96c7639553a340917b87e26d94265c5e (patch)
treea0219df7582ca17784135345924790dc26a7e315 /compiler/PureMicroPasses.ml
parent07621dcf488eef1c4a4ab797c21cc34ab474d225 (diff)
Fix various issues with the generation of code for the loops
Diffstat (limited to '')
-rw-r--r--compiler/PureMicroPasses.ml383
1 files changed, 287 insertions, 96 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 87ab4609..335336be 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -432,12 +432,34 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
(ctx, Switch (scrut, body))
(* *)
and update_loop (loop : loop) (ctx : pn_ctx) : pn_ctx * expression =
- let { fun_end; loop_id; inputs; inputs_lvs; loop_body } = loop in
+ let {
+ fun_end;
+ loop_id;
+ fuel0;
+ fuel;
+ input_state;
+ inputs;
+ inputs_lvs;
+ loop_body;
+ } =
+ loop
+ in
let ctx, fun_end = update_texpression fun_end ctx in
let ctx, loop_body = update_texpression loop_body ctx in
let inputs = List.map (fun input -> update_var ctx input None) inputs in
let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in
- let loop = { fun_end; loop_id; inputs; inputs_lvs; loop_body } in
+ let loop =
+ {
+ fun_end;
+ loop_id;
+ fuel0;
+ fuel;
+ input_state;
+ inputs;
+ inputs_lvs;
+ loop_body;
+ }
+ in
(ctx, Loop loop)
(* *)
and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) :
@@ -972,6 +994,160 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
then None
else Some def
+(** Retrieve the loop definitions from the function definition.
+
+ {!SymbolicToPure} generates an AST in which the loop bodies are part of
+ the function body (see the {!Pure.Loop} node). This function extracts
+ those function bodies into independent definitions while removing
+ occurrences of the {!Pure.Loop} node.
+ *)
+let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
+ (* Store the loops here *)
+ let loops = ref LoopId.Map.empty in
+ let expr_visitor =
+ object (self)
+ inherit [_] map_expression
+
+ method! visit_Loop env loop =
+ let fun_sig = def.signature in
+ let fun_sig_info = fun_sig.info in
+ let fun_effect_info = fun_sig_info.effect_info in
+
+ (* Generate the loop definition *)
+ let loop_effect_info =
+ {
+ stateful_group = fun_effect_info.stateful_group;
+ stateful = fun_effect_info.stateful;
+ can_fail = fun_effect_info.can_fail;
+ can_diverge = fun_effect_info.can_diverge;
+ is_rec = fun_effect_info.is_rec;
+ }
+ in
+
+ let loop_sig_info =
+ let fuel = if !Config.use_fuel then 1 else 0 in
+ let num_inputs = List.length loop.inputs in
+ let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in
+ let num_fwd_inputs_with_fuel_with_state =
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ - fun_sig_info.num_fwd_inputs_with_fuel_no_state
+ in
+ {
+ has_fuel = !Config.use_fuel;
+ num_fwd_inputs_with_fuel_no_state;
+ num_fwd_inputs_with_fuel_with_state;
+ num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state;
+ num_back_inputs_with_state = fun_sig_info.num_back_inputs_with_state;
+ effect_info = loop_effect_info;
+ }
+ in
+
+ let inputs_tys =
+ let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in
+ let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in
+ let state =
+ Collections.List.subslice fun_sig.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_no_state
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ let _, back_inputs =
+ Collections.List.split_at fun_sig.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ List.concat [ fuel; fwd_inputs; state; back_inputs ]
+ in
+
+ let loop_sig =
+ {
+ type_params = fun_sig.type_params;
+ inputs = inputs_tys;
+ output = fun_sig.output;
+ doutputs = fun_sig.doutputs;
+ info = loop_sig_info;
+ }
+ in
+
+ let fuel_vars, inputs, inputs_lvs =
+ (* Introduce the fuel input *)
+ let fuel_vars, fuel0_var, fuel_lvs =
+ if !Config.use_fuel then
+ let fuel0_var = mk_fuel_var loop.fuel0 in
+ let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in
+ (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ])
+ else (None, [], [])
+ in
+
+ (* Introduce the forward input state *)
+ let fwd_state_var, fwd_state_lvs =
+ assert (loop_effect_info.stateful = Option.is_some loop.input_state);
+ match loop.input_state with
+ | None -> ([], [])
+ | Some input_state ->
+ let state_var = mk_state_var input_state in
+ let state_lvs = mk_typed_pattern_from_var state_var None in
+ ([ state_var ], [ state_lvs ])
+ in
+
+ (* Introduce the additional backward inputs *)
+ let fun_body = Option.get def.body in
+ let _, back_inputs =
+ Collections.List.split_at fun_body.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ let _, back_inputs_lvs =
+ Collections.List.split_at fun_body.inputs_lvs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+
+ let inputs =
+ List.concat [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ]
+ in
+ let inputs_lvs =
+ List.concat
+ [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ]
+ in
+ (fuel_vars, inputs, inputs_lvs)
+ in
+
+ (* Wrap the loop body in a match over the fuel *)
+ let loop_body =
+ match fuel_vars with
+ | None -> loop.loop_body
+ | Some (fuel0, fuel) ->
+ SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body
+ in
+
+ let loop_body = { inputs; inputs_lvs; body = loop_body } in
+
+ let loop_def =
+ {
+ def_id = def.def_id;
+ num_loops = 0;
+ loop_id = Some loop.loop_id;
+ back_id = def.back_id;
+ basename = def.basename;
+ signature = loop_sig;
+ is_global_decl_body = def.is_global_decl_body;
+ body = Some loop_body;
+ }
+ in
+ (* Store the loop definition *)
+ loops := LoopId.Map.add_strict loop.loop_id loop_def !loops;
+
+ (* Update the current expression to remove the [Loop] node, and continue *)
+ (self#visit_texpression env loop.fun_end).e
+ end
+ in
+
+ match def.body with
+ | None -> (def, [])
+ | Some body ->
+ let body_expr = expr_visitor#visit_texpression () body.body in
+ let body = { body with body = body_expr } in
+ let def = { def with body = Some body } in
+ let loops = List.map snd (LoopId.Map.bindings !loops) in
+ (def, loops)
+
(** Return [false] if the forward function is useless and should be filtered.
- a forward function with no output (comes from a Rust function with
@@ -989,7 +1165,7 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
altogether.
*)
let keep_forward (trans : pure_fun_translation) : bool =
- let fwd, backs = trans in
+ let (fwd, _), backs = trans in
(* Note that at this point, the output types are no longer seen as tuples:
* they should be lists of length 1. *)
if
@@ -1306,13 +1482,110 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* Return *)
{ def with body = Some body }
+(** Auxiliary function for {!apply_passes_to_def} *)
+let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
+ (* Convert the unit variables to [()] if they are used as right-values or
+ * [_] if they are used as left values. *)
+ let def = unit_vars_to_unit def in
+ log#ldebug
+ (lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Inline the useless variable reassignments *)
+ let inline_named_vars = true in
+ let inline_pure = true in
+ let def =
+ inline_useless_var_reassignments inline_named_vars inline_pure def
+ in
+ log#ldebug
+ (lazy
+ ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Eliminate the box functions - note that the "box" types were eliminated
+ * during the symbolic to pure phase: see the comments for [eliminate_box_functions] *)
+ let def = eliminate_box_functions ctx def in
+ log#ldebug
+ (lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Filter the useless variables, assignments, function calls, etc. *)
+ let def = filter_useless !Config.filter_useless_monadic_calls ctx def in
+ log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Simplify the aggregated ADTs.
+
+ Ex.:
+ {[
+ (* type struct = { f0 : nat; f1 : nat } *)
+
+ Mkstruct x.f0 x.f1 ~~> x
+ ]}
+ *)
+ let def = simplify_aggregates ctx def in
+ log#ldebug
+ (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Decompose the monadic let-bindings - used by Coq *)
+ let def =
+ if !Config.decompose_monadic_let_bindings then (
+ let def = decompose_monadic_let_bindings ctx def in
+ log#ldebug
+ (lazy
+ ("decompose_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def
+ ^ "\n"));
+ def)
+ else (
+ log#ldebug
+ (lazy
+ "ignoring decompose_monadic_let_bindings due to the configuration\n");
+ def)
+ in
+
+ (* Decompose nested let-patterns *)
+ let def =
+ if !Config.decompose_nested_let_patterns then (
+ let def = decompose_nested_let_patterns ctx def in
+ log#ldebug
+ (lazy
+ ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def
+ ^ "\n"));
+ def)
+ else (
+ log#ldebug
+ (lazy
+ "ignoring decompose_nested_let_patterns due to the configuration\n");
+ def)
+ in
+
+ (* Unfold the monadic let-bindings *)
+ let def =
+ if !Config.unfold_monadic_let_bindings then (
+ let def = unfold_monadic_let_bindings ctx def in
+ log#ldebug
+ (lazy
+ ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def
+ ^ "\n"));
+ def)
+ else (
+ log#ldebug
+ (lazy "ignoring unfold_monadic_let_bindings due to the configuration\n");
+ def)
+ in
+
+ (* We are done *)
+ def
+
(** Apply all the micro-passes to a function.
+ As loops are initially directly integrated into the function definition,
+ {!apply_passes_to_def} extracts those loops definitions from the body;
+ it thus returns the pair: (function def, loop defs). See {!decompose_loops}
+ for more information.
+
Will return [None] if the function is a backward function with no outputs.
[ctx]: used only for printing.
*)
-let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =
+let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
+ (fun_decl * fun_decl list) option =
(* Debug *)
log#ldebug
(lazy
@@ -1347,101 +1620,19 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =
match def with
| None -> None
| Some def ->
- (* Convert the unit variables to [()] if they are used as right-values or
- * [_] if they are used as left values. *)
- let def = unit_vars_to_unit def in
- log#ldebug
- (lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
- (* Inline the useless variable reassignments *)
- let inline_named_vars = true in
- let inline_pure = true in
- let def =
- inline_useless_var_reassignments inline_named_vars inline_pure def
- in
- log#ldebug
- (lazy
- ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def
- ^ "\n"));
-
- (* Eliminate the box functions - note that the "box" types were eliminated
- * during the symbolic to pure phase: see the comments for [eliminate_box_functions] *)
- let def = eliminate_box_functions ctx def in
- log#ldebug
- (lazy
- ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
- (* Filter the useless variables, assignments, function calls, etc. *)
- let def = filter_useless !Config.filter_useless_monadic_calls ctx def in
- log#ldebug
- (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+ (* Extract the loop definitions by removing the {!Loop} node *)
+ let def, loops = decompose_loops def in
- (* Simplify the aggregated ADTs.
-
- Ex.:
- {[
- (* type struct = { f0 : nat; f1 : nat } *)
-
- Mkstruct x.f0 x.f1 ~~> x
- ]}
- *)
- let def = simplify_aggregates ctx def in
- log#ldebug
- (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
- (* Decompose the monadic let-bindings - used by Coq *)
- let def =
- if !Config.decompose_monadic_let_bindings then (
- let def = decompose_monadic_let_bindings ctx def in
- log#ldebug
- (lazy
- ("decompose_monadic_let_bindings:\n\n"
- ^ fun_decl_to_string ctx def ^ "\n"));
- def)
- else (
- log#ldebug
- (lazy
- "ignoring decompose_monadic_let_bindings due to the configuration\n");
- def)
- in
-
- (* Decompose nested let-patterns *)
- let def =
- if !Config.decompose_nested_let_patterns then (
- let def = decompose_nested_let_patterns ctx def in
- log#ldebug
- (lazy
- ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def
- ^ "\n"));
- def)
- else (
- log#ldebug
- (lazy
- "ignoring decompose_nested_let_patterns due to the configuration\n");
- def)
- in
-
- (* Unfold the monadic let-bindings *)
- let def =
- if !Config.unfold_monadic_let_bindings then (
- let def = unfold_monadic_let_bindings ctx def in
- log#ldebug
- (lazy
- ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def
- ^ "\n"));
- def)
- else (
- log#ldebug
- (lazy
- "ignoring unfold_monadic_let_bindings due to the configuration\n");
- def)
- in
-
- (* We are done *)
- Some def
+ (* Apply the remaining passes *)
+ let def = apply_end_passes_to_def ctx def in
+ let loops = List.map (apply_end_passes_to_def ctx) loops in
+ Some (def, loops)
(** Return the forward/backward translations on which we applied the micro-passes.
+ This function also extracts the loop definitions from the function body
+ (see {!decompose_loops}).
+
Also returns a boolean indicating whether the forward function should be kept
or not (because useful/useless - [true] means we need to keep the forward
function).
@@ -1450,7 +1641,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =
functions: keeping it is not necessary but more convenient.
*)
let apply_passes_to_pure_fun_translation (ctx : trans_ctx)
- (trans : pure_fun_translation) : bool * pure_fun_translation =
+ (trans : fun_decl * fun_decl list) : bool * pure_fun_translation =
(* Apply the passes to the individual functions *)
let forward, backwards = trans in
let forward = Option.get (apply_passes_to_def ctx forward) in