diff options
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/PureMicroPasses.ml | 281 |
1 files changed, 149 insertions, 132 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 3937db0a..ae791135 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1049,152 +1049,169 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = occurrences of the {!Pure.Loop} node. *) let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = - (* Store the loops here *) - let loops = ref LoopId.Map.empty in - let expr_visitor = - object (self) - inherit [_] map_expression - - method! visit_Loop env loop = - let fun_sig = def.signature in - let fun_sig_info = fun_sig.info in - let fun_effect_info = fun_sig_info.effect_info in - - (* Generate the loop definition *) - let loop_effect_info = - { - stateful_group = fun_effect_info.stateful_group; - stateful = fun_effect_info.stateful; - can_fail = fun_effect_info.can_fail; - can_diverge = fun_effect_info.can_diverge; - is_rec = fun_effect_info.is_rec; - } - in + match def.body with + | None -> (def, []) + | Some body -> + (* Count the number of loops *) + let loops = ref LoopId.Set.empty in + let expr_visitor = + object + inherit [_] iter_expression as super + + method! visit_Loop env loop = + loops := LoopId.Set.add loop.loop_id !loops; + super#visit_Loop env loop + end + in + expr_visitor#visit_texpression () body.body; + let num_loops = LoopId.Set.cardinal !loops in - let loop_sig_info = - let fuel = if !Config.use_fuel then 1 else 0 in - let num_inputs = List.length loop.inputs in - let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in - let fwd_state = - fun_sig_info.num_fwd_inputs_with_fuel_with_state - - fun_sig_info.num_fwd_inputs_with_fuel_no_state - in - let num_fwd_inputs_with_fuel_with_state = - num_fwd_inputs_with_fuel_no_state + fwd_state - in - { - has_fuel = !Config.use_fuel; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state; - num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; - num_back_inputs_with_state = fun_sig_info.num_back_inputs_with_state; - effect_info = loop_effect_info; - } - in + (* Store the loops here *) + let loops = ref LoopId.Map.empty in + let expr_visitor = + object (self) + inherit [_] map_expression + + method! visit_Loop env loop = + let fun_sig = def.signature in + let fun_sig_info = fun_sig.info in + let fun_effect_info = fun_sig_info.effect_info in + + (* Generate the loop definition *) + let loop_effect_info = + { + stateful_group = fun_effect_info.stateful_group; + stateful = fun_effect_info.stateful; + can_fail = fun_effect_info.can_fail; + can_diverge = fun_effect_info.can_diverge; + is_rec = fun_effect_info.is_rec; + } + in - let inputs_tys = - let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in - let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in - let state = - Collections.List.subslice fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_no_state - fun_sig_info.num_fwd_inputs_with_fuel_with_state - in - let _, back_inputs = - Collections.List.split_at fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state - in - List.concat [ fuel; fwd_inputs; state; back_inputs ] - in + let loop_sig_info = + let fuel = if !Config.use_fuel then 1 else 0 in + let num_inputs = List.length loop.inputs in + let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in + let fwd_state = + fun_sig_info.num_fwd_inputs_with_fuel_with_state + - fun_sig_info.num_fwd_inputs_with_fuel_no_state + in + let num_fwd_inputs_with_fuel_with_state = + num_fwd_inputs_with_fuel_no_state + fwd_state + in + { + has_fuel = !Config.use_fuel; + num_fwd_inputs_with_fuel_no_state; + num_fwd_inputs_with_fuel_with_state; + num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; + num_back_inputs_with_state = + fun_sig_info.num_back_inputs_with_state; + effect_info = loop_effect_info; + } + in - let loop_sig = - { - type_params = fun_sig.type_params; - inputs = inputs_tys; - output = fun_sig.output; - doutputs = fun_sig.doutputs; - info = loop_sig_info; - } - in + let inputs_tys = + let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in + let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in + let state = + Collections.List.subslice fun_sig.inputs + fun_sig_info.num_fwd_inputs_with_fuel_no_state + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + let _, back_inputs = + Collections.List.split_at fun_sig.inputs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + List.concat [ fuel; fwd_inputs; state; back_inputs ] + in - let fuel_vars, inputs, inputs_lvs = - (* Introduce the fuel input *) - let fuel_vars, fuel0_var, fuel_lvs = - if !Config.use_fuel then - let fuel0_var = mk_fuel_var loop.fuel0 in - let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in - (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ]) - else (None, [], []) - in + let loop_sig = + { + type_params = fun_sig.type_params; + inputs = inputs_tys; + output = fun_sig.output; + doutputs = fun_sig.doutputs; + info = loop_sig_info; + } + in - (* Introduce the forward input state *) - let fwd_state_var, fwd_state_lvs = - assert (loop_effect_info.stateful = Option.is_some loop.input_state); - match loop.input_state with - | None -> ([], []) - | Some input_state -> - let state_var = mk_state_var input_state in - let state_lvs = mk_typed_pattern_from_var state_var None in - ([ state_var ], [ state_lvs ]) - in + let fuel_vars, inputs, inputs_lvs = + (* Introduce the fuel input *) + let fuel_vars, fuel0_var, fuel_lvs = + if !Config.use_fuel then + let fuel0_var = mk_fuel_var loop.fuel0 in + let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in + (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ]) + else (None, [], []) + in - (* Introduce the additional backward inputs *) - let fun_body = Option.get def.body in - let _, back_inputs = - Collections.List.split_at fun_body.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state - in - let _, back_inputs_lvs = - Collections.List.split_at fun_body.inputs_lvs - fun_sig_info.num_fwd_inputs_with_fuel_with_state - in + (* Introduce the forward input state *) + let fwd_state_var, fwd_state_lvs = + assert ( + loop_effect_info.stateful = Option.is_some loop.input_state); + match loop.input_state with + | None -> ([], []) + | Some input_state -> + let state_var = mk_state_var input_state in + let state_lvs = mk_typed_pattern_from_var state_var None in + ([ state_var ], [ state_lvs ]) + in - let inputs = - List.concat [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ] - in - let inputs_lvs = - List.concat - [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ] - in - (fuel_vars, inputs, inputs_lvs) - in + (* Introduce the additional backward inputs *) + let fun_body = Option.get def.body in + let _, back_inputs = + Collections.List.split_at fun_body.inputs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + let _, back_inputs_lvs = + Collections.List.split_at fun_body.inputs_lvs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in - (* Wrap the loop body in a match over the fuel *) - let loop_body = - match fuel_vars with - | None -> loop.loop_body - | Some (fuel0, fuel) -> - SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body - in + let inputs = + List.concat + [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ] + in + let inputs_lvs = + List.concat + [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ] + in + (fuel_vars, inputs, inputs_lvs) + in - let loop_body = { inputs; inputs_lvs; body = loop_body } in + (* Wrap the loop body in a match over the fuel *) + let loop_body = + match fuel_vars with + | None -> loop.loop_body + | Some (fuel0, fuel) -> + SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body + in - let loop_def = - { - def_id = def.def_id; - num_loops = 0; - loop_id = Some loop.loop_id; - back_id = def.back_id; - basename = def.basename; - signature = loop_sig; - is_global_decl_body = def.is_global_decl_body; - body = Some loop_body; - } - in - (* Store the loop definition *) - loops := LoopId.Map.add_strict loop.loop_id loop_def !loops; + let loop_body = { inputs; inputs_lvs; body = loop_body } in + + let loop_def = + { + def_id = def.def_id; + num_loops; + loop_id = Some loop.loop_id; + back_id = def.back_id; + basename = def.basename; + signature = loop_sig; + is_global_decl_body = def.is_global_decl_body; + body = Some loop_body; + } + in + (* Store the loop definition *) + loops := LoopId.Map.add_strict loop.loop_id loop_def !loops; - (* Update the current expression to remove the [Loop] node, and continue *) - (self#visit_texpression env loop.fun_end).e - end - in + (* Update the current expression to remove the [Loop] node, and continue *) + (self#visit_texpression env loop.fun_end).e + end + in - match def.body with - | None -> (def, []) - | Some body -> let body_expr = expr_visitor#visit_texpression () body.body in let body = { body with body = body_expr } in - let def = { def with body = Some body } in + let def = { def with body = Some body; num_loops } in let loops = List.map snd (LoopId.Map.bindings !loops) in (def, loops) |