summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/PureMicroPasses.ml281
-rw-r--r--tests/coq/misc/Loops.v19
-rw-r--r--tests/fstar/misc/Loops.Funs.fst18
3 files changed, 167 insertions, 151 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 3937db0a..ae791135 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1049,152 +1049,169 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
occurrences of the {!Pure.Loop} node.
*)
let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
- (* Store the loops here *)
- let loops = ref LoopId.Map.empty in
- let expr_visitor =
- object (self)
- inherit [_] map_expression
-
- method! visit_Loop env loop =
- let fun_sig = def.signature in
- let fun_sig_info = fun_sig.info in
- let fun_effect_info = fun_sig_info.effect_info in
-
- (* Generate the loop definition *)
- let loop_effect_info =
- {
- stateful_group = fun_effect_info.stateful_group;
- stateful = fun_effect_info.stateful;
- can_fail = fun_effect_info.can_fail;
- can_diverge = fun_effect_info.can_diverge;
- is_rec = fun_effect_info.is_rec;
- }
- in
+ match def.body with
+ | None -> (def, [])
+ | Some body ->
+ (* Count the number of loops *)
+ let loops = ref LoopId.Set.empty in
+ let expr_visitor =
+ object
+ inherit [_] iter_expression as super
+
+ method! visit_Loop env loop =
+ loops := LoopId.Set.add loop.loop_id !loops;
+ super#visit_Loop env loop
+ end
+ in
+ expr_visitor#visit_texpression () body.body;
+ let num_loops = LoopId.Set.cardinal !loops in
- let loop_sig_info =
- let fuel = if !Config.use_fuel then 1 else 0 in
- let num_inputs = List.length loop.inputs in
- let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in
- let fwd_state =
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- - fun_sig_info.num_fwd_inputs_with_fuel_no_state
- in
- let num_fwd_inputs_with_fuel_with_state =
- num_fwd_inputs_with_fuel_no_state + fwd_state
- in
- {
- has_fuel = !Config.use_fuel;
- num_fwd_inputs_with_fuel_no_state;
- num_fwd_inputs_with_fuel_with_state;
- num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state;
- num_back_inputs_with_state = fun_sig_info.num_back_inputs_with_state;
- effect_info = loop_effect_info;
- }
- in
+ (* Store the loops here *)
+ let loops = ref LoopId.Map.empty in
+ let expr_visitor =
+ object (self)
+ inherit [_] map_expression
+
+ method! visit_Loop env loop =
+ let fun_sig = def.signature in
+ let fun_sig_info = fun_sig.info in
+ let fun_effect_info = fun_sig_info.effect_info in
+
+ (* Generate the loop definition *)
+ let loop_effect_info =
+ {
+ stateful_group = fun_effect_info.stateful_group;
+ stateful = fun_effect_info.stateful;
+ can_fail = fun_effect_info.can_fail;
+ can_diverge = fun_effect_info.can_diverge;
+ is_rec = fun_effect_info.is_rec;
+ }
+ in
- let inputs_tys =
- let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in
- let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in
- let state =
- Collections.List.subslice fun_sig.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_no_state
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- in
- let _, back_inputs =
- Collections.List.split_at fun_sig.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- in
- List.concat [ fuel; fwd_inputs; state; back_inputs ]
- in
+ let loop_sig_info =
+ let fuel = if !Config.use_fuel then 1 else 0 in
+ let num_inputs = List.length loop.inputs in
+ let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in
+ let fwd_state =
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ - fun_sig_info.num_fwd_inputs_with_fuel_no_state
+ in
+ let num_fwd_inputs_with_fuel_with_state =
+ num_fwd_inputs_with_fuel_no_state + fwd_state
+ in
+ {
+ has_fuel = !Config.use_fuel;
+ num_fwd_inputs_with_fuel_no_state;
+ num_fwd_inputs_with_fuel_with_state;
+ num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state;
+ num_back_inputs_with_state =
+ fun_sig_info.num_back_inputs_with_state;
+ effect_info = loop_effect_info;
+ }
+ in
- let loop_sig =
- {
- type_params = fun_sig.type_params;
- inputs = inputs_tys;
- output = fun_sig.output;
- doutputs = fun_sig.doutputs;
- info = loop_sig_info;
- }
- in
+ let inputs_tys =
+ let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in
+ let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in
+ let state =
+ Collections.List.subslice fun_sig.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_no_state
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ let _, back_inputs =
+ Collections.List.split_at fun_sig.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ List.concat [ fuel; fwd_inputs; state; back_inputs ]
+ in
- let fuel_vars, inputs, inputs_lvs =
- (* Introduce the fuel input *)
- let fuel_vars, fuel0_var, fuel_lvs =
- if !Config.use_fuel then
- let fuel0_var = mk_fuel_var loop.fuel0 in
- let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in
- (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ])
- else (None, [], [])
- in
+ let loop_sig =
+ {
+ type_params = fun_sig.type_params;
+ inputs = inputs_tys;
+ output = fun_sig.output;
+ doutputs = fun_sig.doutputs;
+ info = loop_sig_info;
+ }
+ in
- (* Introduce the forward input state *)
- let fwd_state_var, fwd_state_lvs =
- assert (loop_effect_info.stateful = Option.is_some loop.input_state);
- match loop.input_state with
- | None -> ([], [])
- | Some input_state ->
- let state_var = mk_state_var input_state in
- let state_lvs = mk_typed_pattern_from_var state_var None in
- ([ state_var ], [ state_lvs ])
- in
+ let fuel_vars, inputs, inputs_lvs =
+ (* Introduce the fuel input *)
+ let fuel_vars, fuel0_var, fuel_lvs =
+ if !Config.use_fuel then
+ let fuel0_var = mk_fuel_var loop.fuel0 in
+ let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in
+ (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ])
+ else (None, [], [])
+ in
- (* Introduce the additional backward inputs *)
- let fun_body = Option.get def.body in
- let _, back_inputs =
- Collections.List.split_at fun_body.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- in
- let _, back_inputs_lvs =
- Collections.List.split_at fun_body.inputs_lvs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- in
+ (* Introduce the forward input state *)
+ let fwd_state_var, fwd_state_lvs =
+ assert (
+ loop_effect_info.stateful = Option.is_some loop.input_state);
+ match loop.input_state with
+ | None -> ([], [])
+ | Some input_state ->
+ let state_var = mk_state_var input_state in
+ let state_lvs = mk_typed_pattern_from_var state_var None in
+ ([ state_var ], [ state_lvs ])
+ in
- let inputs =
- List.concat [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ]
- in
- let inputs_lvs =
- List.concat
- [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ]
- in
- (fuel_vars, inputs, inputs_lvs)
- in
+ (* Introduce the additional backward inputs *)
+ let fun_body = Option.get def.body in
+ let _, back_inputs =
+ Collections.List.split_at fun_body.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ let _, back_inputs_lvs =
+ Collections.List.split_at fun_body.inputs_lvs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
- (* Wrap the loop body in a match over the fuel *)
- let loop_body =
- match fuel_vars with
- | None -> loop.loop_body
- | Some (fuel0, fuel) ->
- SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body
- in
+ let inputs =
+ List.concat
+ [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ]
+ in
+ let inputs_lvs =
+ List.concat
+ [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ]
+ in
+ (fuel_vars, inputs, inputs_lvs)
+ in
- let loop_body = { inputs; inputs_lvs; body = loop_body } in
+ (* Wrap the loop body in a match over the fuel *)
+ let loop_body =
+ match fuel_vars with
+ | None -> loop.loop_body
+ | Some (fuel0, fuel) ->
+ SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body
+ in
- let loop_def =
- {
- def_id = def.def_id;
- num_loops = 0;
- loop_id = Some loop.loop_id;
- back_id = def.back_id;
- basename = def.basename;
- signature = loop_sig;
- is_global_decl_body = def.is_global_decl_body;
- body = Some loop_body;
- }
- in
- (* Store the loop definition *)
- loops := LoopId.Map.add_strict loop.loop_id loop_def !loops;
+ let loop_body = { inputs; inputs_lvs; body = loop_body } in
+
+ let loop_def =
+ {
+ def_id = def.def_id;
+ num_loops;
+ loop_id = Some loop.loop_id;
+ back_id = def.back_id;
+ basename = def.basename;
+ signature = loop_sig;
+ is_global_decl_body = def.is_global_decl_body;
+ body = Some loop_body;
+ }
+ in
+ (* Store the loop definition *)
+ loops := LoopId.Map.add_strict loop.loop_id loop_def !loops;
- (* Update the current expression to remove the [Loop] node, and continue *)
- (self#visit_texpression env loop.fun_end).e
- end
- in
+ (* Update the current expression to remove the [Loop] node, and continue *)
+ (self#visit_texpression env loop.fun_end).e
+ end
+ in
- match def.body with
- | None -> (def, [])
- | Some body ->
let body_expr = expr_visitor#visit_texpression () body.body in
let body = { body with body = body_expr } in
- let def = { def with body = Some body } in
+ let def = { def with body = Some body; num_loops } in
let loops = List.map snd (LoopId.Map.bindings !loops) in
(def, loops)
diff --git a/tests/coq/misc/Loops.v b/tests/coq/misc/Loops.v
index 72c47361..8d552b5b 100644
--- a/tests/coq/misc/Loops.v
+++ b/tests/coq/misc/Loops.v
@@ -7,20 +7,19 @@ Local Open Scope Primitives_scope.
Module Loops.
(** [loops::sum] *)
-Fixpoint sum_loop0_fwd
- (n : nat) (max : u32) (i : u32) (s : u32) : result u32 :=
+Fixpoint sum_loop_fwd (n : nat) (max : u32) (i : u32) (s : u32) : result u32 :=
match n with
| O => Fail_ OutOfFuel
| S n0 =>
if i s< max
- then (s0 <- u32_add s i; i0 <- u32_add i 1%u32; sum_loop0_fwd n0 max i0 s0)
+ then (s0 <- u32_add s i; i0 <- u32_add i 1%u32; sum_loop_fwd n0 max i0 s0)
else u32_mul s 2%u32
end
.
(** [loops::sum] *)
Definition sum_fwd (n : nat) (max : u32) : result u32 :=
- sum_loop0_fwd n max (0%u32) (0%u32)
+ sum_loop_fwd n max (0%u32) (0%u32)
.
(** [loops::List] *)
@@ -33,7 +32,7 @@ Arguments ListCons {T} _ _.
Arguments ListNil {T}.
(** [loops::list_nth_mut_loop] *)
-Fixpoint list_nth_mut_loop_loop0_fwd
+Fixpoint list_nth_mut_loop_loop_fwd
(T : Type) (n : nat) (ls : List_t T) (i : u32) : result T :=
match n with
| O => Fail_ OutOfFuel
@@ -42,7 +41,7 @@ Fixpoint list_nth_mut_loop_loop0_fwd
| ListCons x tl =>
if i s= 0%u32
then Return x
- else (i0 <- u32_sub i 1%u32; list_nth_mut_loop_loop0_fwd T n0 tl i0)
+ else (i0 <- u32_sub i 1%u32; list_nth_mut_loop_loop_fwd T n0 tl i0)
| ListNil => Fail_ Failure
end
end
@@ -51,11 +50,11 @@ Fixpoint list_nth_mut_loop_loop0_fwd
(** [loops::list_nth_mut_loop] *)
Definition list_nth_mut_loop_fwd
(T : Type) (n : nat) (ls : List_t T) (i : u32) : result T :=
- list_nth_mut_loop_loop0_fwd T n ls i
+ list_nth_mut_loop_loop_fwd T n ls i
.
(** [loops::list_nth_mut_loop] *)
-Fixpoint list_nth_mut_loop_loop0_back
+Fixpoint list_nth_mut_loop_loop_back
(T : Type) (n : nat) (ls : List_t T) (i : u32) (ret : T) :
result (List_t T)
:=
@@ -68,7 +67,7 @@ Fixpoint list_nth_mut_loop_loop0_back
then Return (ListCons ret tl)
else (
i0 <- u32_sub i 1%u32;
- l <- list_nth_mut_loop_loop0_back T n0 tl i0 ret;
+ l <- list_nth_mut_loop_loop_back T n0 tl i0 ret;
Return (ListCons x l))
| ListNil => Fail_ Failure
end
@@ -80,7 +79,7 @@ Definition list_nth_mut_loop_back
(T : Type) (n : nat) (ls : List_t T) (i : u32) (ret : T) :
result (List_t T)
:=
- list_nth_mut_loop_loop0_back T n ls i ret
+ list_nth_mut_loop_loop_back T n ls i ret
.
End Loops .
diff --git a/tests/fstar/misc/Loops.Funs.fst b/tests/fstar/misc/Loops.Funs.fst
index a2ae2563..cf05b7f2 100644
--- a/tests/fstar/misc/Loops.Funs.fst
+++ b/tests/fstar/misc/Loops.Funs.fst
@@ -8,7 +8,7 @@ include Loops.Clauses
#set-options "--z3rlimit 50 --fuel 1 --ifuel 1"
(** [loops::sum] *)
-let rec sum_loop0_fwd
+let rec sum_loop_fwd
(max : u32) (i : u32) (s : u32) :
Tot (result u32) (decreases (sum_decreases max i s))
=
@@ -19,16 +19,16 @@ let rec sum_loop0_fwd
| Return s0 ->
begin match u32_add i 1 with
| Fail e -> Fail e
- | Return i0 -> sum_loop0_fwd max i0 s0
+ | Return i0 -> sum_loop_fwd max i0 s0
end
end
else u32_mul s 2
(** [loops::sum] *)
-let sum_fwd (max : u32) : result u32 = sum_loop0_fwd max 0 0
+let sum_fwd (max : u32) : result u32 = sum_loop_fwd max 0 0
(** [loops::list_nth_mut_loop] *)
-let rec list_nth_mut_loop_loop0_fwd
+let rec list_nth_mut_loop_loop_fwd
(t : Type0) (ls : list_t t) (i : u32) :
Tot (result t) (decreases (list_nth_mut_loop_decreases t ls i))
=
@@ -39,17 +39,17 @@ let rec list_nth_mut_loop_loop0_fwd
else
begin match u32_sub i 1 with
| Fail e -> Fail e
- | Return i0 -> list_nth_mut_loop_loop0_fwd t tl i0
+ | Return i0 -> list_nth_mut_loop_loop_fwd t tl i0
end
| ListNil -> Fail Failure
end
(** [loops::list_nth_mut_loop] *)
let list_nth_mut_loop_fwd (t : Type0) (ls : list_t t) (i : u32) : result t =
- list_nth_mut_loop_loop0_fwd t ls i
+ list_nth_mut_loop_loop_fwd t ls i
(** [loops::list_nth_mut_loop] *)
-let rec list_nth_mut_loop_loop0_back
+let rec list_nth_mut_loop_loop_back
(t : Type0) (ls : list_t t) (i : u32) (ret : t) :
Tot (result (list_t t)) (decreases (list_nth_mut_loop_decreases t ls i))
=
@@ -61,7 +61,7 @@ let rec list_nth_mut_loop_loop0_back
begin match u32_sub i 1 with
| Fail e -> Fail e
| Return i0 ->
- begin match list_nth_mut_loop_loop0_back t tl i0 ret with
+ begin match list_nth_mut_loop_loop_back t tl i0 ret with
| Fail e -> Fail e
| Return l -> Return (ListCons x l)
end
@@ -72,5 +72,5 @@ let rec list_nth_mut_loop_loop0_back
(** [loops::list_nth_mut_loop] *)
let list_nth_mut_loop_back
(t : Type0) (ls : list_t t) (i : u32) (ret : t) : result (list_t t) =
- list_nth_mut_loop_loop0_back t ls i ret
+ list_nth_mut_loop_loop_back t ls i ret