summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/PureMicroPasses.ml281
1 files changed, 149 insertions, 132 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 3937db0a..ae791135 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1049,152 +1049,169 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
occurrences of the {!Pure.Loop} node.
*)
let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
- (* Store the loops here *)
- let loops = ref LoopId.Map.empty in
- let expr_visitor =
- object (self)
- inherit [_] map_expression
-
- method! visit_Loop env loop =
- let fun_sig = def.signature in
- let fun_sig_info = fun_sig.info in
- let fun_effect_info = fun_sig_info.effect_info in
-
- (* Generate the loop definition *)
- let loop_effect_info =
- {
- stateful_group = fun_effect_info.stateful_group;
- stateful = fun_effect_info.stateful;
- can_fail = fun_effect_info.can_fail;
- can_diverge = fun_effect_info.can_diverge;
- is_rec = fun_effect_info.is_rec;
- }
- in
+ match def.body with
+ | None -> (def, [])
+ | Some body ->
+ (* Count the number of loops *)
+ let loops = ref LoopId.Set.empty in
+ let expr_visitor =
+ object
+ inherit [_] iter_expression as super
+
+ method! visit_Loop env loop =
+ loops := LoopId.Set.add loop.loop_id !loops;
+ super#visit_Loop env loop
+ end
+ in
+ expr_visitor#visit_texpression () body.body;
+ let num_loops = LoopId.Set.cardinal !loops in
- let loop_sig_info =
- let fuel = if !Config.use_fuel then 1 else 0 in
- let num_inputs = List.length loop.inputs in
- let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in
- let fwd_state =
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- - fun_sig_info.num_fwd_inputs_with_fuel_no_state
- in
- let num_fwd_inputs_with_fuel_with_state =
- num_fwd_inputs_with_fuel_no_state + fwd_state
- in
- {
- has_fuel = !Config.use_fuel;
- num_fwd_inputs_with_fuel_no_state;
- num_fwd_inputs_with_fuel_with_state;
- num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state;
- num_back_inputs_with_state = fun_sig_info.num_back_inputs_with_state;
- effect_info = loop_effect_info;
- }
- in
+ (* Store the loops here *)
+ let loops = ref LoopId.Map.empty in
+ let expr_visitor =
+ object (self)
+ inherit [_] map_expression
+
+ method! visit_Loop env loop =
+ let fun_sig = def.signature in
+ let fun_sig_info = fun_sig.info in
+ let fun_effect_info = fun_sig_info.effect_info in
+
+ (* Generate the loop definition *)
+ let loop_effect_info =
+ {
+ stateful_group = fun_effect_info.stateful_group;
+ stateful = fun_effect_info.stateful;
+ can_fail = fun_effect_info.can_fail;
+ can_diverge = fun_effect_info.can_diverge;
+ is_rec = fun_effect_info.is_rec;
+ }
+ in
- let inputs_tys =
- let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in
- let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in
- let state =
- Collections.List.subslice fun_sig.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_no_state
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- in
- let _, back_inputs =
- Collections.List.split_at fun_sig.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- in
- List.concat [ fuel; fwd_inputs; state; back_inputs ]
- in
+ let loop_sig_info =
+ let fuel = if !Config.use_fuel then 1 else 0 in
+ let num_inputs = List.length loop.inputs in
+ let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in
+ let fwd_state =
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ - fun_sig_info.num_fwd_inputs_with_fuel_no_state
+ in
+ let num_fwd_inputs_with_fuel_with_state =
+ num_fwd_inputs_with_fuel_no_state + fwd_state
+ in
+ {
+ has_fuel = !Config.use_fuel;
+ num_fwd_inputs_with_fuel_no_state;
+ num_fwd_inputs_with_fuel_with_state;
+ num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state;
+ num_back_inputs_with_state =
+ fun_sig_info.num_back_inputs_with_state;
+ effect_info = loop_effect_info;
+ }
+ in
- let loop_sig =
- {
- type_params = fun_sig.type_params;
- inputs = inputs_tys;
- output = fun_sig.output;
- doutputs = fun_sig.doutputs;
- info = loop_sig_info;
- }
- in
+ let inputs_tys =
+ let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in
+ let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in
+ let state =
+ Collections.List.subslice fun_sig.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_no_state
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ let _, back_inputs =
+ Collections.List.split_at fun_sig.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ List.concat [ fuel; fwd_inputs; state; back_inputs ]
+ in
- let fuel_vars, inputs, inputs_lvs =
- (* Introduce the fuel input *)
- let fuel_vars, fuel0_var, fuel_lvs =
- if !Config.use_fuel then
- let fuel0_var = mk_fuel_var loop.fuel0 in
- let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in
- (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ])
- else (None, [], [])
- in
+ let loop_sig =
+ {
+ type_params = fun_sig.type_params;
+ inputs = inputs_tys;
+ output = fun_sig.output;
+ doutputs = fun_sig.doutputs;
+ info = loop_sig_info;
+ }
+ in
- (* Introduce the forward input state *)
- let fwd_state_var, fwd_state_lvs =
- assert (loop_effect_info.stateful = Option.is_some loop.input_state);
- match loop.input_state with
- | None -> ([], [])
- | Some input_state ->
- let state_var = mk_state_var input_state in
- let state_lvs = mk_typed_pattern_from_var state_var None in
- ([ state_var ], [ state_lvs ])
- in
+ let fuel_vars, inputs, inputs_lvs =
+ (* Introduce the fuel input *)
+ let fuel_vars, fuel0_var, fuel_lvs =
+ if !Config.use_fuel then
+ let fuel0_var = mk_fuel_var loop.fuel0 in
+ let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in
+ (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ])
+ else (None, [], [])
+ in
- (* Introduce the additional backward inputs *)
- let fun_body = Option.get def.body in
- let _, back_inputs =
- Collections.List.split_at fun_body.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- in
- let _, back_inputs_lvs =
- Collections.List.split_at fun_body.inputs_lvs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- in
+ (* Introduce the forward input state *)
+ let fwd_state_var, fwd_state_lvs =
+ assert (
+ loop_effect_info.stateful = Option.is_some loop.input_state);
+ match loop.input_state with
+ | None -> ([], [])
+ | Some input_state ->
+ let state_var = mk_state_var input_state in
+ let state_lvs = mk_typed_pattern_from_var state_var None in
+ ([ state_var ], [ state_lvs ])
+ in
- let inputs =
- List.concat [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ]
- in
- let inputs_lvs =
- List.concat
- [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ]
- in
- (fuel_vars, inputs, inputs_lvs)
- in
+ (* Introduce the additional backward inputs *)
+ let fun_body = Option.get def.body in
+ let _, back_inputs =
+ Collections.List.split_at fun_body.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ let _, back_inputs_lvs =
+ Collections.List.split_at fun_body.inputs_lvs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
- (* Wrap the loop body in a match over the fuel *)
- let loop_body =
- match fuel_vars with
- | None -> loop.loop_body
- | Some (fuel0, fuel) ->
- SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body
- in
+ let inputs =
+ List.concat
+ [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ]
+ in
+ let inputs_lvs =
+ List.concat
+ [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ]
+ in
+ (fuel_vars, inputs, inputs_lvs)
+ in
- let loop_body = { inputs; inputs_lvs; body = loop_body } in
+ (* Wrap the loop body in a match over the fuel *)
+ let loop_body =
+ match fuel_vars with
+ | None -> loop.loop_body
+ | Some (fuel0, fuel) ->
+ SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body
+ in
- let loop_def =
- {
- def_id = def.def_id;
- num_loops = 0;
- loop_id = Some loop.loop_id;
- back_id = def.back_id;
- basename = def.basename;
- signature = loop_sig;
- is_global_decl_body = def.is_global_decl_body;
- body = Some loop_body;
- }
- in
- (* Store the loop definition *)
- loops := LoopId.Map.add_strict loop.loop_id loop_def !loops;
+ let loop_body = { inputs; inputs_lvs; body = loop_body } in
+
+ let loop_def =
+ {
+ def_id = def.def_id;
+ num_loops;
+ loop_id = Some loop.loop_id;
+ back_id = def.back_id;
+ basename = def.basename;
+ signature = loop_sig;
+ is_global_decl_body = def.is_global_decl_body;
+ body = Some loop_body;
+ }
+ in
+ (* Store the loop definition *)
+ loops := LoopId.Map.add_strict loop.loop_id loop_def !loops;
- (* Update the current expression to remove the [Loop] node, and continue *)
- (self#visit_texpression env loop.fun_end).e
- end
- in
+ (* Update the current expression to remove the [Loop] node, and continue *)
+ (self#visit_texpression env loop.fun_end).e
+ end
+ in
- match def.body with
- | None -> (def, [])
- | Some body ->
let body_expr = expr_visitor#visit_texpression () body.body in
let body = { body with body = body_expr } in
- let def = { def with body = Some body } in
+ let def = { def with body = Some body; num_loops } in
let loops = List.map snd (LoopId.Map.bindings !loops) in
(def, loops)