From b0454e54744eeedfe2e9e4c8c1dcb592020bb615 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sat, 17 Dec 2022 14:40:10 +0100 Subject: Improve the loops' numbering --- compiler/PureMicroPasses.ml | 281 +++++++++++++++++++++------------------- tests/coq/misc/Loops.v | 19 ++- tests/fstar/misc/Loops.Funs.fst | 18 +-- 3 files changed, 167 insertions(+), 151 deletions(-) diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 3937db0a..ae791135 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1049,152 +1049,169 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = occurrences of the {!Pure.Loop} node. *) let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = - (* Store the loops here *) - let loops = ref LoopId.Map.empty in - let expr_visitor = - object (self) - inherit [_] map_expression - - method! visit_Loop env loop = - let fun_sig = def.signature in - let fun_sig_info = fun_sig.info in - let fun_effect_info = fun_sig_info.effect_info in - - (* Generate the loop definition *) - let loop_effect_info = - { - stateful_group = fun_effect_info.stateful_group; - stateful = fun_effect_info.stateful; - can_fail = fun_effect_info.can_fail; - can_diverge = fun_effect_info.can_diverge; - is_rec = fun_effect_info.is_rec; - } - in + match def.body with + | None -> (def, []) + | Some body -> + (* Count the number of loops *) + let loops = ref LoopId.Set.empty in + let expr_visitor = + object + inherit [_] iter_expression as super + + method! visit_Loop env loop = + loops := LoopId.Set.add loop.loop_id !loops; + super#visit_Loop env loop + end + in + expr_visitor#visit_texpression () body.body; + let num_loops = LoopId.Set.cardinal !loops in - let loop_sig_info = - let fuel = if !Config.use_fuel then 1 else 0 in - let num_inputs = List.length loop.inputs in - let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in - let fwd_state = - fun_sig_info.num_fwd_inputs_with_fuel_with_state - - fun_sig_info.num_fwd_inputs_with_fuel_no_state - in - let num_fwd_inputs_with_fuel_with_state = - num_fwd_inputs_with_fuel_no_state + fwd_state - in - { - has_fuel = !Config.use_fuel; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state; - num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; - num_back_inputs_with_state = fun_sig_info.num_back_inputs_with_state; - effect_info = loop_effect_info; - } - in + (* Store the loops here *) + let loops = ref LoopId.Map.empty in + let expr_visitor = + object (self) + inherit [_] map_expression + + method! visit_Loop env loop = + let fun_sig = def.signature in + let fun_sig_info = fun_sig.info in + let fun_effect_info = fun_sig_info.effect_info in + + (* Generate the loop definition *) + let loop_effect_info = + { + stateful_group = fun_effect_info.stateful_group; + stateful = fun_effect_info.stateful; + can_fail = fun_effect_info.can_fail; + can_diverge = fun_effect_info.can_diverge; + is_rec = fun_effect_info.is_rec; + } + in - let inputs_tys = - let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in - let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in - let state = - Collections.List.subslice fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_no_state - fun_sig_info.num_fwd_inputs_with_fuel_with_state - in - let _, back_inputs = - Collections.List.split_at fun_sig.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state - in - List.concat [ fuel; fwd_inputs; state; back_inputs ] - in + let loop_sig_info = + let fuel = if !Config.use_fuel then 1 else 0 in + let num_inputs = List.length loop.inputs in + let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in + let fwd_state = + fun_sig_info.num_fwd_inputs_with_fuel_with_state + - fun_sig_info.num_fwd_inputs_with_fuel_no_state + in + let num_fwd_inputs_with_fuel_with_state = + num_fwd_inputs_with_fuel_no_state + fwd_state + in + { + has_fuel = !Config.use_fuel; + num_fwd_inputs_with_fuel_no_state; + num_fwd_inputs_with_fuel_with_state; + num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state; + num_back_inputs_with_state = + fun_sig_info.num_back_inputs_with_state; + effect_info = loop_effect_info; + } + in - let loop_sig = - { - type_params = fun_sig.type_params; - inputs = inputs_tys; - output = fun_sig.output; - doutputs = fun_sig.doutputs; - info = loop_sig_info; - } - in + let inputs_tys = + let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in + let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in + let state = + Collections.List.subslice fun_sig.inputs + fun_sig_info.num_fwd_inputs_with_fuel_no_state + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + let _, back_inputs = + Collections.List.split_at fun_sig.inputs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + List.concat [ fuel; fwd_inputs; state; back_inputs ] + in - let fuel_vars, inputs, inputs_lvs = - (* Introduce the fuel input *) - let fuel_vars, fuel0_var, fuel_lvs = - if !Config.use_fuel then - let fuel0_var = mk_fuel_var loop.fuel0 in - let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in - (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ]) - else (None, [], []) - in + let loop_sig = + { + type_params = fun_sig.type_params; + inputs = inputs_tys; + output = fun_sig.output; + doutputs = fun_sig.doutputs; + info = loop_sig_info; + } + in - (* Introduce the forward input state *) - let fwd_state_var, fwd_state_lvs = - assert (loop_effect_info.stateful = Option.is_some loop.input_state); - match loop.input_state with - | None -> ([], []) - | Some input_state -> - let state_var = mk_state_var input_state in - let state_lvs = mk_typed_pattern_from_var state_var None in - ([ state_var ], [ state_lvs ]) - in + let fuel_vars, inputs, inputs_lvs = + (* Introduce the fuel input *) + let fuel_vars, fuel0_var, fuel_lvs = + if !Config.use_fuel then + let fuel0_var = mk_fuel_var loop.fuel0 in + let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in + (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ]) + else (None, [], []) + in - (* Introduce the additional backward inputs *) - let fun_body = Option.get def.body in - let _, back_inputs = - Collections.List.split_at fun_body.inputs - fun_sig_info.num_fwd_inputs_with_fuel_with_state - in - let _, back_inputs_lvs = - Collections.List.split_at fun_body.inputs_lvs - fun_sig_info.num_fwd_inputs_with_fuel_with_state - in + (* Introduce the forward input state *) + let fwd_state_var, fwd_state_lvs = + assert ( + loop_effect_info.stateful = Option.is_some loop.input_state); + match loop.input_state with + | None -> ([], []) + | Some input_state -> + let state_var = mk_state_var input_state in + let state_lvs = mk_typed_pattern_from_var state_var None in + ([ state_var ], [ state_lvs ]) + in - let inputs = - List.concat [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ] - in - let inputs_lvs = - List.concat - [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ] - in - (fuel_vars, inputs, inputs_lvs) - in + (* Introduce the additional backward inputs *) + let fun_body = Option.get def.body in + let _, back_inputs = + Collections.List.split_at fun_body.inputs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in + let _, back_inputs_lvs = + Collections.List.split_at fun_body.inputs_lvs + fun_sig_info.num_fwd_inputs_with_fuel_with_state + in - (* Wrap the loop body in a match over the fuel *) - let loop_body = - match fuel_vars with - | None -> loop.loop_body - | Some (fuel0, fuel) -> - SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body - in + let inputs = + List.concat + [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ] + in + let inputs_lvs = + List.concat + [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ] + in + (fuel_vars, inputs, inputs_lvs) + in - let loop_body = { inputs; inputs_lvs; body = loop_body } in + (* Wrap the loop body in a match over the fuel *) + let loop_body = + match fuel_vars with + | None -> loop.loop_body + | Some (fuel0, fuel) -> + SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body + in - let loop_def = - { - def_id = def.def_id; - num_loops = 0; - loop_id = Some loop.loop_id; - back_id = def.back_id; - basename = def.basename; - signature = loop_sig; - is_global_decl_body = def.is_global_decl_body; - body = Some loop_body; - } - in - (* Store the loop definition *) - loops := LoopId.Map.add_strict loop.loop_id loop_def !loops; + let loop_body = { inputs; inputs_lvs; body = loop_body } in + + let loop_def = + { + def_id = def.def_id; + num_loops; + loop_id = Some loop.loop_id; + back_id = def.back_id; + basename = def.basename; + signature = loop_sig; + is_global_decl_body = def.is_global_decl_body; + body = Some loop_body; + } + in + (* Store the loop definition *) + loops := LoopId.Map.add_strict loop.loop_id loop_def !loops; - (* Update the current expression to remove the [Loop] node, and continue *) - (self#visit_texpression env loop.fun_end).e - end - in + (* Update the current expression to remove the [Loop] node, and continue *) + (self#visit_texpression env loop.fun_end).e + end + in - match def.body with - | None -> (def, []) - | Some body -> let body_expr = expr_visitor#visit_texpression () body.body in let body = { body with body = body_expr } in - let def = { def with body = Some body } in + let def = { def with body = Some body; num_loops } in let loops = List.map snd (LoopId.Map.bindings !loops) in (def, loops) diff --git a/tests/coq/misc/Loops.v b/tests/coq/misc/Loops.v index 72c47361..8d552b5b 100644 --- a/tests/coq/misc/Loops.v +++ b/tests/coq/misc/Loops.v @@ -7,20 +7,19 @@ Local Open Scope Primitives_scope. Module Loops. (** [loops::sum] *) -Fixpoint sum_loop0_fwd - (n : nat) (max : u32) (i : u32) (s : u32) : result u32 := +Fixpoint sum_loop_fwd (n : nat) (max : u32) (i : u32) (s : u32) : result u32 := match n with | O => Fail_ OutOfFuel | S n0 => if i s< max - then (s0 <- u32_add s i; i0 <- u32_add i 1%u32; sum_loop0_fwd n0 max i0 s0) + then (s0 <- u32_add s i; i0 <- u32_add i 1%u32; sum_loop_fwd n0 max i0 s0) else u32_mul s 2%u32 end . (** [loops::sum] *) Definition sum_fwd (n : nat) (max : u32) : result u32 := - sum_loop0_fwd n max (0%u32) (0%u32) + sum_loop_fwd n max (0%u32) (0%u32) . (** [loops::List] *) @@ -33,7 +32,7 @@ Arguments ListCons {T} _ _. Arguments ListNil {T}. (** [loops::list_nth_mut_loop] *) -Fixpoint list_nth_mut_loop_loop0_fwd +Fixpoint list_nth_mut_loop_loop_fwd (T : Type) (n : nat) (ls : List_t T) (i : u32) : result T := match n with | O => Fail_ OutOfFuel @@ -42,7 +41,7 @@ Fixpoint list_nth_mut_loop_loop0_fwd | ListCons x tl => if i s= 0%u32 then Return x - else (i0 <- u32_sub i 1%u32; list_nth_mut_loop_loop0_fwd T n0 tl i0) + else (i0 <- u32_sub i 1%u32; list_nth_mut_loop_loop_fwd T n0 tl i0) | ListNil => Fail_ Failure end end @@ -51,11 +50,11 @@ Fixpoint list_nth_mut_loop_loop0_fwd (** [loops::list_nth_mut_loop] *) Definition list_nth_mut_loop_fwd (T : Type) (n : nat) (ls : List_t T) (i : u32) : result T := - list_nth_mut_loop_loop0_fwd T n ls i + list_nth_mut_loop_loop_fwd T n ls i . (** [loops::list_nth_mut_loop] *) -Fixpoint list_nth_mut_loop_loop0_back +Fixpoint list_nth_mut_loop_loop_back (T : Type) (n : nat) (ls : List_t T) (i : u32) (ret : T) : result (List_t T) := @@ -68,7 +67,7 @@ Fixpoint list_nth_mut_loop_loop0_back then Return (ListCons ret tl) else ( i0 <- u32_sub i 1%u32; - l <- list_nth_mut_loop_loop0_back T n0 tl i0 ret; + l <- list_nth_mut_loop_loop_back T n0 tl i0 ret; Return (ListCons x l)) | ListNil => Fail_ Failure end @@ -80,7 +79,7 @@ Definition list_nth_mut_loop_back (T : Type) (n : nat) (ls : List_t T) (i : u32) (ret : T) : result (List_t T) := - list_nth_mut_loop_loop0_back T n ls i ret + list_nth_mut_loop_loop_back T n ls i ret . End Loops . diff --git a/tests/fstar/misc/Loops.Funs.fst b/tests/fstar/misc/Loops.Funs.fst index a2ae2563..cf05b7f2 100644 --- a/tests/fstar/misc/Loops.Funs.fst +++ b/tests/fstar/misc/Loops.Funs.fst @@ -8,7 +8,7 @@ include Loops.Clauses #set-options "--z3rlimit 50 --fuel 1 --ifuel 1" (** [loops::sum] *) -let rec sum_loop0_fwd +let rec sum_loop_fwd (max : u32) (i : u32) (s : u32) : Tot (result u32) (decreases (sum_decreases max i s)) = @@ -19,16 +19,16 @@ let rec sum_loop0_fwd | Return s0 -> begin match u32_add i 1 with | Fail e -> Fail e - | Return i0 -> sum_loop0_fwd max i0 s0 + | Return i0 -> sum_loop_fwd max i0 s0 end end else u32_mul s 2 (** [loops::sum] *) -let sum_fwd (max : u32) : result u32 = sum_loop0_fwd max 0 0 +let sum_fwd (max : u32) : result u32 = sum_loop_fwd max 0 0 (** [loops::list_nth_mut_loop] *) -let rec list_nth_mut_loop_loop0_fwd +let rec list_nth_mut_loop_loop_fwd (t : Type0) (ls : list_t t) (i : u32) : Tot (result t) (decreases (list_nth_mut_loop_decreases t ls i)) = @@ -39,17 +39,17 @@ let rec list_nth_mut_loop_loop0_fwd else begin match u32_sub i 1 with | Fail e -> Fail e - | Return i0 -> list_nth_mut_loop_loop0_fwd t tl i0 + | Return i0 -> list_nth_mut_loop_loop_fwd t tl i0 end | ListNil -> Fail Failure end (** [loops::list_nth_mut_loop] *) let list_nth_mut_loop_fwd (t : Type0) (ls : list_t t) (i : u32) : result t = - list_nth_mut_loop_loop0_fwd t ls i + list_nth_mut_loop_loop_fwd t ls i (** [loops::list_nth_mut_loop] *) -let rec list_nth_mut_loop_loop0_back +let rec list_nth_mut_loop_loop_back (t : Type0) (ls : list_t t) (i : u32) (ret : t) : Tot (result (list_t t)) (decreases (list_nth_mut_loop_decreases t ls i)) = @@ -61,7 +61,7 @@ let rec list_nth_mut_loop_loop0_back begin match u32_sub i 1 with | Fail e -> Fail e | Return i0 -> - begin match list_nth_mut_loop_loop0_back t tl i0 ret with + begin match list_nth_mut_loop_loop_back t tl i0 ret with | Fail e -> Fail e | Return l -> Return (ListCons x l) end @@ -72,5 +72,5 @@ let rec list_nth_mut_loop_loop0_back (** [loops::list_nth_mut_loop] *) let list_nth_mut_loop_back (t : Type0) (ls : list_t t) (i : u32) (ret : t) : result (list_t t) = - list_nth_mut_loop_loop0_back t ls i ret + list_nth_mut_loop_loop_back t ls i ret -- cgit v1.2.3