diff options
author | Son Ho | 2022-12-17 10:27:12 +0100 |
---|---|---|
committer | Son HO | 2023-02-03 11:21:46 +0100 |
commit | 66638a2a96c7639553a340917b87e26d94265c5e (patch) | |
tree | a0219df7582ca17784135345924790dc26a7e315 /compiler/PureMicroPasses.ml | |
parent | 07621dcf488eef1c4a4ab797c21cc34ab474d225 (diff) |
Fix various issues with the generation of code for the loops
Diffstat (limited to '')
-rw-r--r-- | compiler/PureMicroPasses.ml | 383 |
1 files changed, 287 insertions, 96 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 87ab4609..335336be 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -432,12 +432,34 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (ctx, Switch (scrut, body)) (* *) and update_loop (loop : loop) (ctx : pn_ctx) : pn_ctx * expression = - let { fun_end; loop_id; inputs; inputs_lvs; loop_body } = loop in + let { + fun_end; + loop_id; + fuel0; + fuel; + input_state; + inputs; + inputs_lvs; + loop_body; + } = + loop + in let ctx, fun_end = update_texpression fun_end ctx in let ctx, loop_body = update_texpression loop_body ctx in let inputs = List.map (fun input -> update_var ctx input None) inputs in let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in - let loop = { fun_end; loop_id; inputs; inputs_lvs; loop_body } in + let loop = + { + fun_end; + loop_id; + fuel0; + fuel; + input_state; + inputs; + inputs_lvs; + loop_body; + } + in (ctx, Loop loop) (* *) and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) : @@ -972,6 +994,160 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = then None else Some def +(** Retrieve the loop definitions from the function definition. + + {!SymbolicToPure} generates an AST in which the loop bodies are part of + the function body (see the {!Pure.Loop} node). This function extracts + those function bodies into independent definitions while removing + occurrences of the {!Pure.Loop} node. + *) +let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = + (* Store the loops here *) + let loops = ref LoopId.Map.empty in + let expr_visitor = + object (self) + inherit [_] map_expression + + method! visit_Loop env loop = + let fun_sig = def.signature in + let fun_sig_info = fun_sig.info in + let fun_effect_info = fun_sig_info.effect_info in + + (* Generate the loop definition *) + let loop_effect_info = + { + stateful_group = fun_effect_info.stateful_group; + stateful = fun_effect_info.stateful; + can_fail = fun_effect_info.can_fail; + can_diverge = fun_effect_info.can_diverge; + is_rec = fun_effect_info.is_rec; + } + in + + let loop_sig_info = + let fuel = if !Config.use_fuel then 1 else 0 in + let num_inputs = List.length loop.inputs in + let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in + let num_fwd_inputs_with_fuel_with_state = + fun_sig_info.num_fwd_inputs_with_fuel_with_state + - fun_sig_info.num_fwd_inputs_with_fuel_no_state + in + { + has_fuel = !Config.use_fuel; + num_fwd_inputs_with_fuel_no_state; + num_fwd_inputs_with_fuel_with_state; + num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; + num_back_inputs_with_state = fun_sig_info.num_back_inputs_with_state; + effect_info = loop_effect_info; + } + in + + let inputs_tys = + let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in + let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in + let state = + Collections.List.subslice fun_sig.inputs + fun_sig_info.num_fwd_inputs_with_fuel_no_state + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + let _, back_inputs = + Collections.List.split_at fun_sig.inputs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + List.concat [ fuel; fwd_inputs; state; back_inputs ] + in + + let loop_sig = + { + type_params = fun_sig.type_params; + inputs = inputs_tys; + output = fun_sig.output; + doutputs = fun_sig.doutputs; + info = loop_sig_info; + } + in + + let fuel_vars, inputs, inputs_lvs = + (* Introduce the fuel input *) + let fuel_vars, fuel0_var, fuel_lvs = + if !Config.use_fuel then + let fuel0_var = mk_fuel_var loop.fuel0 in + let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in + (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ]) + else (None, [], []) + in + + (* Introduce the forward input state *) + let fwd_state_var, fwd_state_lvs = + assert (loop_effect_info.stateful = Option.is_some loop.input_state); + match loop.input_state with + | None -> ([], []) + | Some input_state -> + let state_var = mk_state_var input_state in + let state_lvs = mk_typed_pattern_from_var state_var None in + ([ state_var ], [ state_lvs ]) + in + + (* Introduce the additional backward inputs *) + let fun_body = Option.get def.body in + let _, back_inputs = + Collections.List.split_at fun_body.inputs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + let _, back_inputs_lvs = + Collections.List.split_at fun_body.inputs_lvs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + + let inputs = + List.concat [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ] + in + let inputs_lvs = + List.concat + [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ] + in + (fuel_vars, inputs, inputs_lvs) + in + + (* Wrap the loop body in a match over the fuel *) + let loop_body = + match fuel_vars with + | None -> loop.loop_body + | Some (fuel0, fuel) -> + SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body + in + + let loop_body = { inputs; inputs_lvs; body = loop_body } in + + let loop_def = + { + def_id = def.def_id; + num_loops = 0; + loop_id = Some loop.loop_id; + back_id = def.back_id; + basename = def.basename; + signature = loop_sig; + is_global_decl_body = def.is_global_decl_body; + body = Some loop_body; + } + in + (* Store the loop definition *) + loops := LoopId.Map.add_strict loop.loop_id loop_def !loops; + + (* Update the current expression to remove the [Loop] node, and continue *) + (self#visit_texpression env loop.fun_end).e + end + in + + match def.body with + | None -> (def, []) + | Some body -> + let body_expr = expr_visitor#visit_texpression () body.body in + let body = { body with body = body_expr } in + let def = { def with body = Some body } in + let loops = List.map snd (LoopId.Map.bindings !loops) in + (def, loops) + (** Return [false] if the forward function is useless and should be filtered. - a forward function with no output (comes from a Rust function with @@ -989,7 +1165,7 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = altogether. *) let keep_forward (trans : pure_fun_translation) : bool = - let fwd, backs = trans in + let (fwd, _), backs = trans in (* Note that at this point, the output types are no longer seen as tuples: * they should be lists of length 1. *) if @@ -1306,13 +1482,110 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = (* Return *) { def with body = Some body } +(** Auxiliary function for {!apply_passes_to_def} *) +let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = + (* Convert the unit variables to [()] if they are used as right-values or + * [_] if they are used as left values. *) + let def = unit_vars_to_unit def in + log#ldebug + (lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Inline the useless variable reassignments *) + let inline_named_vars = true in + let inline_pure = true in + let def = + inline_useless_var_reassignments inline_named_vars inline_pure def + in + log#ldebug + (lazy + ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Eliminate the box functions - note that the "box" types were eliminated + * during the symbolic to pure phase: see the comments for [eliminate_box_functions] *) + let def = eliminate_box_functions ctx def in + log#ldebug + (lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Filter the useless variables, assignments, function calls, etc. *) + let def = filter_useless !Config.filter_useless_monadic_calls ctx def in + log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Simplify the aggregated ADTs. + + Ex.: + {[ + (* type struct = { f0 : nat; f1 : nat } *) + + Mkstruct x.f0 x.f1 ~~> x + ]} + *) + let def = simplify_aggregates ctx def in + log#ldebug + (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Decompose the monadic let-bindings - used by Coq *) + let def = + if !Config.decompose_monadic_let_bindings then ( + let def = decompose_monadic_let_bindings ctx def in + log#ldebug + (lazy + ("decompose_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def + ^ "\n")); + def) + else ( + log#ldebug + (lazy + "ignoring decompose_monadic_let_bindings due to the configuration\n"); + def) + in + + (* Decompose nested let-patterns *) + let def = + if !Config.decompose_nested_let_patterns then ( + let def = decompose_nested_let_patterns ctx def in + log#ldebug + (lazy + ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def + ^ "\n")); + def) + else ( + log#ldebug + (lazy + "ignoring decompose_nested_let_patterns due to the configuration\n"); + def) + in + + (* Unfold the monadic let-bindings *) + let def = + if !Config.unfold_monadic_let_bindings then ( + let def = unfold_monadic_let_bindings ctx def in + log#ldebug + (lazy + ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def + ^ "\n")); + def) + else ( + log#ldebug + (lazy "ignoring unfold_monadic_let_bindings due to the configuration\n"); + def) + in + + (* We are done *) + def + (** Apply all the micro-passes to a function. + As loops are initially directly integrated into the function definition, + {!apply_passes_to_def} extracts those loops definitions from the body; + it thus returns the pair: (function def, loop defs). See {!decompose_loops} + for more information. + Will return [None] if the function is a backward function with no outputs. [ctx]: used only for printing. *) -let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = +let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : + (fun_decl * fun_decl list) option = (* Debug *) log#ldebug (lazy @@ -1347,101 +1620,19 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = match def with | None -> None | Some def -> - (* Convert the unit variables to [()] if they are used as right-values or - * [_] if they are used as left values. *) - let def = unit_vars_to_unit def in - log#ldebug - (lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* Inline the useless variable reassignments *) - let inline_named_vars = true in - let inline_pure = true in - let def = - inline_useless_var_reassignments inline_named_vars inline_pure def - in - log#ldebug - (lazy - ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def - ^ "\n")); - - (* Eliminate the box functions - note that the "box" types were eliminated - * during the symbolic to pure phase: see the comments for [eliminate_box_functions] *) - let def = eliminate_box_functions ctx def in - log#ldebug - (lazy - ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* Filter the useless variables, assignments, function calls, etc. *) - let def = filter_useless !Config.filter_useless_monadic_calls ctx def in - log#ldebug - (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Extract the loop definitions by removing the {!Loop} node *) + let def, loops = decompose_loops def in - (* Simplify the aggregated ADTs. - - Ex.: - {[ - (* type struct = { f0 : nat; f1 : nat } *) - - Mkstruct x.f0 x.f1 ~~> x - ]} - *) - let def = simplify_aggregates ctx def in - log#ldebug - (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - - (* Decompose the monadic let-bindings - used by Coq *) - let def = - if !Config.decompose_monadic_let_bindings then ( - let def = decompose_monadic_let_bindings ctx def in - log#ldebug - (lazy - ("decompose_monadic_let_bindings:\n\n" - ^ fun_decl_to_string ctx def ^ "\n")); - def) - else ( - log#ldebug - (lazy - "ignoring decompose_monadic_let_bindings due to the configuration\n"); - def) - in - - (* Decompose nested let-patterns *) - let def = - if !Config.decompose_nested_let_patterns then ( - let def = decompose_nested_let_patterns ctx def in - log#ldebug - (lazy - ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def - ^ "\n")); - def) - else ( - log#ldebug - (lazy - "ignoring decompose_nested_let_patterns due to the configuration\n"); - def) - in - - (* Unfold the monadic let-bindings *) - let def = - if !Config.unfold_monadic_let_bindings then ( - let def = unfold_monadic_let_bindings ctx def in - log#ldebug - (lazy - ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def - ^ "\n")); - def) - else ( - log#ldebug - (lazy - "ignoring unfold_monadic_let_bindings due to the configuration\n"); - def) - in - - (* We are done *) - Some def + (* Apply the remaining passes *) + let def = apply_end_passes_to_def ctx def in + let loops = List.map (apply_end_passes_to_def ctx) loops in + Some (def, loops) (** Return the forward/backward translations on which we applied the micro-passes. + This function also extracts the loop definitions from the function body + (see {!decompose_loops}). + Also returns a boolean indicating whether the forward function should be kept or not (because useful/useless - [true] means we need to keep the forward function). @@ -1450,7 +1641,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = functions: keeping it is not necessary but more convenient. *) let apply_passes_to_pure_fun_translation (ctx : trans_ctx) - (trans : pure_fun_translation) : bool * pure_fun_translation = + (trans : fun_decl * fun_decl list) : bool * pure_fun_translation = (* Apply the passes to the individual functions *) let forward, backwards = trans in let forward = Option.get (apply_passes_to_def ctx forward) in |