summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2022-12-17 10:27:12 +0100
committerSon HO2023-02-03 11:21:46 +0100
commit66638a2a96c7639553a340917b87e26d94265c5e (patch)
treea0219df7582ca17784135345924790dc26a7e315 /compiler
parent07621dcf488eef1c4a4ab797c21cc34ab474d225 (diff)
Fix various issues with the generation of code for the loops
Diffstat (limited to '')
-rw-r--r--compiler/Extract.ml29
-rw-r--r--compiler/ExtractBase.ml21
-rw-r--r--compiler/Pure.ml5
-rw-r--r--compiler/PureMicroPasses.ml383
-rw-r--r--compiler/PureUtils.ml15
-rw-r--r--compiler/SymbolicToPure.ml23
-rw-r--r--compiler/Translate.ml85
-rw-r--r--compiler/TranslateCore.ml4
8 files changed, 416 insertions, 149 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index fa384de6..b3d7b49e 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1254,21 +1254,28 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx)
(forward function and backward functions).
*)
let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool)
- (has_decreases_clause : bool) (def : pure_fun_translation) : extraction_ctx
- =
- let fwd, back_ls = def in
- (* Register the decrease clause, if necessary *)
- let ctx =
- if has_decreases_clause then ctx_add_decrases_clause fwd ctx else ctx
+ (has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) :
+ extraction_ctx =
+ let (fwd, loop_fwds), back_ls = def in
+ (* Register the decrease clauses, if necessary *)
+ let register_decreases ctx def =
+ if has_decreases_clause def then ctx_add_decreases_clause def ctx else ctx
in
- (* Register the forward function name *)
- let ctx = ctx_add_fun_decl (keep_fwd, def) fwd ctx in
+ let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in
+ (* Register the function names *)
+ let register_fun ctx f = ctx_add_fun_decl (keep_fwd, def) f ctx in
+ let register_funs ctx fl = List.fold_left register_fun ctx fl in
+ (* Register the forward functions' names *)
+ let ctx = register_funs ctx (fwd :: loop_fwds) in
(* Register the backward functions' names *)
let ctx =
List.fold_left
- (fun ctx back -> ctx_add_fun_decl (keep_fwd, def) back ctx)
+ (fun ctx (back, loop_backs) ->
+ let ctx = register_fun ctx back in
+ register_funs ctx loop_backs)
ctx back_ls
in
+
(* Return *)
ctx
@@ -1855,7 +1862,7 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
assert (!backend = FStar);
(* Retrieve the function name *)
- let def_name = ctx_get_decreases_clause def.def_id ctx in
+ let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in
(* Add a break before *)
F.pp_print_break fmt 0 0;
(* Print a comment to link the extracted type to its original rust definition *)
@@ -1992,7 +1999,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
(* Open a box for the decreases term *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* The name of the decrease clause *)
- let decr_name = ctx_get_decreases_clause def.def_id ctx in
+ let decr_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in
F.pp_print_string fmt decr_name;
(* Print the type parameters *)
List.iter
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index c1ea536a..b952d555 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -273,7 +273,7 @@ type formatter = {
type id =
| GlobalId of A.GlobalDeclId.id
| FunId of fun_id
- | DecreasesClauseId of A.fun_id
+ | DecreasesClauseId of (A.fun_id * LoopId.id option)
(** The definition which provides the decreases/termination clause.
We insert calls to this clause to prove/reason about termination:
the body of those clauses must be defined by the user, in the
@@ -467,14 +467,19 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
in
"fun name (" ^ lp_kind ^ fwd_back_kind ^ "): " ^ fun_name
| Pure fid -> PrintPure.pure_assumed_fun_id_to_string fid)
- | DecreasesClauseId fid ->
+ | DecreasesClauseId (fid, lid) ->
let fun_name =
match fid with
| Regular fid ->
Print.fun_name_to_string (A.FunDeclId.Map.find fid fun_decls).name
| Assumed aid -> A.show_assumed_fun_id aid
in
- "decreases clause for function: " ^ fun_name
+ let loop =
+ match lid with
+ | None -> ""
+ | Some lid -> ", loop: " ^ LoopId.to_string lid
+ in
+ "decreases clause for function: " ^ fun_name ^ loop
| TypeId id -> "type name: " ^ get_type_name id
| StructId id -> "struct constructor of: " ^ get_type_name id
| VariantId (id, variant_id) ->
@@ -581,9 +586,9 @@ let ctx_get_variant (def_id : type_id) (variant_id : VariantId.id)
(ctx : extraction_ctx) : string =
ctx_get (VariantId (def_id, variant_id)) ctx
-let ctx_get_decreases_clause (def_id : A.FunDeclId.id) (ctx : extraction_ctx) :
- string =
- ctx_get (DecreasesClauseId (Regular def_id)) ctx
+let ctx_get_decreases_clause (def_id : A.FunDeclId.id)
+ (loop_id : LoopId.id option) (ctx : extraction_ctx) : string =
+ ctx_get (DecreasesClauseId (Regular def_id, loop_id)) ctx
(** Generate a unique type variable name and add it to the context *)
let ctx_add_type_var (basename : string) (id : TypeVarId.id)
@@ -669,10 +674,10 @@ let ctx_add_struct (def : type_decl) (ctx : extraction_ctx) :
let ctx = ctx_add (StructId (AdtId def.def_id)) name ctx in
(ctx, name)
-let ctx_add_decrases_clause (def : fun_decl) (ctx : extraction_ctx) :
+let ctx_add_decreases_clause (def : fun_decl) (ctx : extraction_ctx) :
extraction_ctx =
let name = ctx.fmt.decreases_clause_name def.def_id def.basename in
- ctx_add (DecreasesClauseId (Regular def.def_id)) name ctx
+ ctx_add (DecreasesClauseId (Regular def.def_id, def.loop_id)) name ctx
let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
extraction_ctx =
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 6fb20b22..97eced1d 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -14,7 +14,7 @@ module SymbolicValueId = V.SymbolicValueId
module FunDeclId = A.FunDeclId
module GlobalDeclId = A.GlobalDeclId
-(** We redefine identifiers for loop: in {Values}, the identifiers are global
+(** We redefine identifiers for loop: in {!Values}, the identifiers are global
(they monotonically increase across functions) while in {!module:Pure} we want
the indices to start at 0 for every function.
*)
@@ -492,6 +492,9 @@ and match_branch = { pat : typed_pattern; branch : texpression }
and loop = {
fun_end : texpression;
loop_id : loop_id;
+ fuel0 : var_id;
+ fuel : var_id;
+ input_state : var_id option;
inputs : var list;
inputs_lvs : typed_pattern list;
(** The inputs seen as patterns. See {!fun_body}. *)
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 87ab4609..335336be 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -432,12 +432,34 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
(ctx, Switch (scrut, body))
(* *)
and update_loop (loop : loop) (ctx : pn_ctx) : pn_ctx * expression =
- let { fun_end; loop_id; inputs; inputs_lvs; loop_body } = loop in
+ let {
+ fun_end;
+ loop_id;
+ fuel0;
+ fuel;
+ input_state;
+ inputs;
+ inputs_lvs;
+ loop_body;
+ } =
+ loop
+ in
let ctx, fun_end = update_texpression fun_end ctx in
let ctx, loop_body = update_texpression loop_body ctx in
let inputs = List.map (fun input -> update_var ctx input None) inputs in
let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in
- let loop = { fun_end; loop_id; inputs; inputs_lvs; loop_body } in
+ let loop =
+ {
+ fun_end;
+ loop_id;
+ fuel0;
+ fuel;
+ input_state;
+ inputs;
+ inputs_lvs;
+ loop_body;
+ }
+ in
(ctx, Loop loop)
(* *)
and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) :
@@ -972,6 +994,160 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
then None
else Some def
+(** Retrieve the loop definitions from the function definition.
+
+ {!SymbolicToPure} generates an AST in which the loop bodies are part of
+ the function body (see the {!Pure.Loop} node). This function extracts
+ those function bodies into independent definitions while removing
+ occurrences of the {!Pure.Loop} node.
+ *)
+let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
+ (* Store the loops here *)
+ let loops = ref LoopId.Map.empty in
+ let expr_visitor =
+ object (self)
+ inherit [_] map_expression
+
+ method! visit_Loop env loop =
+ let fun_sig = def.signature in
+ let fun_sig_info = fun_sig.info in
+ let fun_effect_info = fun_sig_info.effect_info in
+
+ (* Generate the loop definition *)
+ let loop_effect_info =
+ {
+ stateful_group = fun_effect_info.stateful_group;
+ stateful = fun_effect_info.stateful;
+ can_fail = fun_effect_info.can_fail;
+ can_diverge = fun_effect_info.can_diverge;
+ is_rec = fun_effect_info.is_rec;
+ }
+ in
+
+ let loop_sig_info =
+ let fuel = if !Config.use_fuel then 1 else 0 in
+ let num_inputs = List.length loop.inputs in
+ let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in
+ let num_fwd_inputs_with_fuel_with_state =
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ - fun_sig_info.num_fwd_inputs_with_fuel_no_state
+ in
+ {
+ has_fuel = !Config.use_fuel;
+ num_fwd_inputs_with_fuel_no_state;
+ num_fwd_inputs_with_fuel_with_state;
+ num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state;
+ num_back_inputs_with_state = fun_sig_info.num_back_inputs_with_state;
+ effect_info = loop_effect_info;
+ }
+ in
+
+ let inputs_tys =
+ let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in
+ let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in
+ let state =
+ Collections.List.subslice fun_sig.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_no_state
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ let _, back_inputs =
+ Collections.List.split_at fun_sig.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ List.concat [ fuel; fwd_inputs; state; back_inputs ]
+ in
+
+ let loop_sig =
+ {
+ type_params = fun_sig.type_params;
+ inputs = inputs_tys;
+ output = fun_sig.output;
+ doutputs = fun_sig.doutputs;
+ info = loop_sig_info;
+ }
+ in
+
+ let fuel_vars, inputs, inputs_lvs =
+ (* Introduce the fuel input *)
+ let fuel_vars, fuel0_var, fuel_lvs =
+ if !Config.use_fuel then
+ let fuel0_var = mk_fuel_var loop.fuel0 in
+ let fuel_lvs = mk_typed_pattern_from_var fuel0_var None in
+ (Some (loop.fuel0, loop.fuel), [ fuel0_var ], [ fuel_lvs ])
+ else (None, [], [])
+ in
+
+ (* Introduce the forward input state *)
+ let fwd_state_var, fwd_state_lvs =
+ assert (loop_effect_info.stateful = Option.is_some loop.input_state);
+ match loop.input_state with
+ | None -> ([], [])
+ | Some input_state ->
+ let state_var = mk_state_var input_state in
+ let state_lvs = mk_typed_pattern_from_var state_var None in
+ ([ state_var ], [ state_lvs ])
+ in
+
+ (* Introduce the additional backward inputs *)
+ let fun_body = Option.get def.body in
+ let _, back_inputs =
+ Collections.List.split_at fun_body.inputs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+ let _, back_inputs_lvs =
+ Collections.List.split_at fun_body.inputs_lvs
+ fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ in
+
+ let inputs =
+ List.concat [ fuel0_var; fwd_state_var; loop.inputs; back_inputs ]
+ in
+ let inputs_lvs =
+ List.concat
+ [ fuel_lvs; fwd_state_lvs; loop.inputs_lvs; back_inputs_lvs ]
+ in
+ (fuel_vars, inputs, inputs_lvs)
+ in
+
+ (* Wrap the loop body in a match over the fuel *)
+ let loop_body =
+ match fuel_vars with
+ | None -> loop.loop_body
+ | Some (fuel0, fuel) ->
+ SymbolicToPure.wrap_in_match_fuel fuel0 fuel loop.loop_body
+ in
+
+ let loop_body = { inputs; inputs_lvs; body = loop_body } in
+
+ let loop_def =
+ {
+ def_id = def.def_id;
+ num_loops = 0;
+ loop_id = Some loop.loop_id;
+ back_id = def.back_id;
+ basename = def.basename;
+ signature = loop_sig;
+ is_global_decl_body = def.is_global_decl_body;
+ body = Some loop_body;
+ }
+ in
+ (* Store the loop definition *)
+ loops := LoopId.Map.add_strict loop.loop_id loop_def !loops;
+
+ (* Update the current expression to remove the [Loop] node, and continue *)
+ (self#visit_texpression env loop.fun_end).e
+ end
+ in
+
+ match def.body with
+ | None -> (def, [])
+ | Some body ->
+ let body_expr = expr_visitor#visit_texpression () body.body in
+ let body = { body with body = body_expr } in
+ let def = { def with body = Some body } in
+ let loops = List.map snd (LoopId.Map.bindings !loops) in
+ (def, loops)
+
(** Return [false] if the forward function is useless and should be filtered.
- a forward function with no output (comes from a Rust function with
@@ -989,7 +1165,7 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
altogether.
*)
let keep_forward (trans : pure_fun_translation) : bool =
- let fwd, backs = trans in
+ let (fwd, _), backs = trans in
(* Note that at this point, the output types are no longer seen as tuples:
* they should be lists of length 1. *)
if
@@ -1306,13 +1482,110 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* Return *)
{ def with body = Some body }
+(** Auxiliary function for {!apply_passes_to_def} *)
+let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
+ (* Convert the unit variables to [()] if they are used as right-values or
+ * [_] if they are used as left values. *)
+ let def = unit_vars_to_unit def in
+ log#ldebug
+ (lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Inline the useless variable reassignments *)
+ let inline_named_vars = true in
+ let inline_pure = true in
+ let def =
+ inline_useless_var_reassignments inline_named_vars inline_pure def
+ in
+ log#ldebug
+ (lazy
+ ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Eliminate the box functions - note that the "box" types were eliminated
+ * during the symbolic to pure phase: see the comments for [eliminate_box_functions] *)
+ let def = eliminate_box_functions ctx def in
+ log#ldebug
+ (lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Filter the useless variables, assignments, function calls, etc. *)
+ let def = filter_useless !Config.filter_useless_monadic_calls ctx def in
+ log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Simplify the aggregated ADTs.
+
+ Ex.:
+ {[
+ (* type struct = { f0 : nat; f1 : nat } *)
+
+ Mkstruct x.f0 x.f1 ~~> x
+ ]}
+ *)
+ let def = simplify_aggregates ctx def in
+ log#ldebug
+ (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
+ (* Decompose the monadic let-bindings - used by Coq *)
+ let def =
+ if !Config.decompose_monadic_let_bindings then (
+ let def = decompose_monadic_let_bindings ctx def in
+ log#ldebug
+ (lazy
+ ("decompose_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def
+ ^ "\n"));
+ def)
+ else (
+ log#ldebug
+ (lazy
+ "ignoring decompose_monadic_let_bindings due to the configuration\n");
+ def)
+ in
+
+ (* Decompose nested let-patterns *)
+ let def =
+ if !Config.decompose_nested_let_patterns then (
+ let def = decompose_nested_let_patterns ctx def in
+ log#ldebug
+ (lazy
+ ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def
+ ^ "\n"));
+ def)
+ else (
+ log#ldebug
+ (lazy
+ "ignoring decompose_nested_let_patterns due to the configuration\n");
+ def)
+ in
+
+ (* Unfold the monadic let-bindings *)
+ let def =
+ if !Config.unfold_monadic_let_bindings then (
+ let def = unfold_monadic_let_bindings ctx def in
+ log#ldebug
+ (lazy
+ ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def
+ ^ "\n"));
+ def)
+ else (
+ log#ldebug
+ (lazy "ignoring unfold_monadic_let_bindings due to the configuration\n");
+ def)
+ in
+
+ (* We are done *)
+ def
+
(** Apply all the micro-passes to a function.
+ As loops are initially directly integrated into the function definition,
+ {!apply_passes_to_def} extracts those loops definitions from the body;
+ it thus returns the pair: (function def, loop defs). See {!decompose_loops}
+ for more information.
+
Will return [None] if the function is a backward function with no outputs.
[ctx]: used only for printing.
*)
-let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =
+let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
+ (fun_decl * fun_decl list) option =
(* Debug *)
log#ldebug
(lazy
@@ -1347,101 +1620,19 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =
match def with
| None -> None
| Some def ->
- (* Convert the unit variables to [()] if they are used as right-values or
- * [_] if they are used as left values. *)
- let def = unit_vars_to_unit def in
- log#ldebug
- (lazy ("unit_vars_to_unit:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
- (* Inline the useless variable reassignments *)
- let inline_named_vars = true in
- let inline_pure = true in
- let def =
- inline_useless_var_reassignments inline_named_vars inline_pure def
- in
- log#ldebug
- (lazy
- ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def
- ^ "\n"));
-
- (* Eliminate the box functions - note that the "box" types were eliminated
- * during the symbolic to pure phase: see the comments for [eliminate_box_functions] *)
- let def = eliminate_box_functions ctx def in
- log#ldebug
- (lazy
- ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
- (* Filter the useless variables, assignments, function calls, etc. *)
- let def = filter_useless !Config.filter_useless_monadic_calls ctx def in
- log#ldebug
- (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+ (* Extract the loop definitions by removing the {!Loop} node *)
+ let def, loops = decompose_loops def in
- (* Simplify the aggregated ADTs.
-
- Ex.:
- {[
- (* type struct = { f0 : nat; f1 : nat } *)
-
- Mkstruct x.f0 x.f1 ~~> x
- ]}
- *)
- let def = simplify_aggregates ctx def in
- log#ldebug
- (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
- (* Decompose the monadic let-bindings - used by Coq *)
- let def =
- if !Config.decompose_monadic_let_bindings then (
- let def = decompose_monadic_let_bindings ctx def in
- log#ldebug
- (lazy
- ("decompose_monadic_let_bindings:\n\n"
- ^ fun_decl_to_string ctx def ^ "\n"));
- def)
- else (
- log#ldebug
- (lazy
- "ignoring decompose_monadic_let_bindings due to the configuration\n");
- def)
- in
-
- (* Decompose nested let-patterns *)
- let def =
- if !Config.decompose_nested_let_patterns then (
- let def = decompose_nested_let_patterns ctx def in
- log#ldebug
- (lazy
- ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def
- ^ "\n"));
- def)
- else (
- log#ldebug
- (lazy
- "ignoring decompose_nested_let_patterns due to the configuration\n");
- def)
- in
-
- (* Unfold the monadic let-bindings *)
- let def =
- if !Config.unfold_monadic_let_bindings then (
- let def = unfold_monadic_let_bindings ctx def in
- log#ldebug
- (lazy
- ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def
- ^ "\n"));
- def)
- else (
- log#ldebug
- (lazy
- "ignoring unfold_monadic_let_bindings due to the configuration\n");
- def)
- in
-
- (* We are done *)
- Some def
+ (* Apply the remaining passes *)
+ let def = apply_end_passes_to_def ctx def in
+ let loops = List.map (apply_end_passes_to_def ctx) loops in
+ Some (def, loops)
(** Return the forward/backward translations on which we applied the micro-passes.
+ This function also extracts the loop definitions from the function body
+ (see {!decompose_loops}).
+
Also returns a boolean indicating whether the forward function should be kept
or not (because useful/useless - [true] means we need to keep the forward
function).
@@ -1450,7 +1641,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =
functions: keeping it is not necessary but more convenient.
*)
let apply_passes_to_pure_fun_translation (ctx : trans_ctx)
- (trans : pure_fun_translation) : bool * pure_fun_translation =
+ (trans : fun_decl * fun_decl list) : bool * pure_fun_translation =
(* Apply the passes to the individual functions *)
let forward, backwards = trans in
let forward = Option.get (apply_passes_to_def ctx forward) in
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index b5c9b686..e1421f5a 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -7,6 +7,9 @@ let log = Logging.pure_utils_log
type regular_fun_id = A.fun_id * T.RegionGroupId.id option
[@@deriving show, ord]
+(** We use this type as a key for lookups *)
+type fun_loop_id = A.FunDeclId.id * LoopId.id option [@@deriving show, ord]
+
module RegularFunIdOrderedType = struct
type t = regular_fun_id
@@ -30,6 +33,18 @@ end
module FunOrOpIdMap = Collections.MakeMap (FunOrOpIdOrderedType)
module FunOrOpIdSet = Collections.MakeSet (FunOrOpIdOrderedType)
+module FunLoopIdOrderedType = struct
+ type t = fun_loop_id
+
+ let compare = compare_fun_loop_id
+ let to_string = show_fun_loop_id
+ let pp_t = pp_fun_loop_id
+ let show_t = show_fun_loop_id
+end
+
+module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType)
+module FunLoopIdSet = Collections.MakeSet (FunLoopIdOrderedType)
+
let dest_arrow_ty (ty : ty) : ty * ty =
match ty with
| Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty)
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index a2b41165..ad603bd5 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -2261,7 +2261,19 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
let loop_body = translate_expression loop.loop_expr ctx_loop in
(* Create the loop node and return *)
- let loop = Loop { fun_end; loop_id; inputs; inputs_lvs; loop_body } in
+ let loop =
+ Loop
+ {
+ fun_end;
+ loop_id;
+ fuel0 = ctx.fuel0;
+ fuel = ctx.fuel;
+ input_state = (if !Config.use_state then Some ctx.state_var else None);
+ inputs;
+ inputs_lvs;
+ loop_body;
+ }
+ in
assert (fun_end.ty = loop_body.ty);
let ty = fun_end.ty in
{ e = loop; ty }
@@ -2282,10 +2294,11 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
{ e; ty }
(** Wrap a function body in a match over the fuel to control termination. *)
-let wrap_in_match_fuel (body : texpression) (ctx : bs_ctx) : texpression =
- let fuel0_var : var = mk_fuel_var ctx.fuel0 in
+let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
+ : texpression =
+ let fuel0_var : var = mk_fuel_var fuel0 in
let fuel0 = mk_texpression_from_var fuel0_var in
- let nfuel_var : var = mk_fuel_var ctx.fuel in
+ let nfuel_var : var = mk_fuel_var fuel in
let nfuel_pat = mk_typed_pattern_from_var nfuel_var None in
let fail_branch =
mk_result_fail_texpression_with_error_id error_out_of_fuel_id body.ty
@@ -2376,7 +2389,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(* Add a match over the fuel, if necessary *)
let body =
if function_decreases_fuel effect_info then
- wrap_in_match_fuel body ctx
+ wrap_in_match_fuel ctx.fuel0 ctx.fuel body
else body
in
(* Sanity check *)
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 32c32ac4..10a37770 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -51,7 +51,7 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl)
let translate_function_to_pure (trans_ctx : trans_ctx)
(fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t)
(pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl)
- : pure_fun_translation =
+ : pure_fun_translation_no_loops =
(* Debug *)
log#ldebug
(lazy
@@ -213,7 +213,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
sg.info.num_fwd_inputs_with_fuel_with_state
in
let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in
- Collections.List.subslice sg.inputs num_forward_inputs num_back_inputs
+ Collections.List.subslice sg.inputs num_forward_inputs
+ (num_forward_inputs + num_back_inputs)
in
(* As we forbid nested borrows, the additional inputs for the backward
* functions come from the borrows in the return value of the rust function:
@@ -336,7 +337,7 @@ type gen_ctx = {
extract_ctx : ExtractBase.extraction_ctx;
trans_types : Pure.type_decl Pure.TypeDeclId.Map.t;
trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t;
- functions_with_decreases_clause : A.FunDeclId.Set.t;
+ functions_with_decreases_clause : PureUtils.FunLoopIdSet.t;
}
type gen_config = {
@@ -370,7 +371,7 @@ let module_has_opaque_decls (ctx : gen_ctx) : bool * bool =
in
let has_opaque_funs =
A.FunDeclId.Map.exists
- (fun _ ((_, (t_fwd, _)) : bool * pure_fun_translation) ->
+ (fun _ ((_, ((t_fwd, _), _)) : bool * pure_fun_translation) ->
Option.is_none t_fwd.body)
ctx.trans_funs
in
@@ -452,10 +453,11 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
(id : A.GlobalDeclId.id) : unit =
let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in
let global = A.GlobalDeclId.Map.find id global_decls in
- let _, (body, body_backs) =
+ let _, ((body, loop_fwds), body_backs) =
A.FunDeclId.Map.find global.body_id ctx.trans_funs
in
- assert (List.length body_backs = 0);
+ assert (body_backs = []);
+ assert (loop_fwds = []);
let is_opaque = Option.is_none body.Pure.body in
if
@@ -487,7 +489,8 @@ let export_functions_declarations (fmt : Format.formatter) (config : gen_config)
(ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit =
(* Utility to check a function has a decrease clause *)
let has_decreases_clause (def : Pure.fun_decl) : bool =
- A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
+ PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
+ ctx.functions_with_decreases_clause
in
(* Extract the function declarations *)
@@ -532,16 +535,21 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit =
(* Utility to check a function has a decrease clause *)
let has_decreases_clause (def : Pure.fun_decl) : bool =
- A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
+ PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
+ ctx.functions_with_decreases_clause
in
(* Extract the decrease clauses template bodies *)
if config.extract_template_decreases_clauses then
List.iter
- (fun (_, (fwd, _)) ->
- let has_decr_clause = has_decreases_clause fwd in
- if has_decr_clause then
- Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd)
+ (fun (_, ((fwd, loop_fwds), _)) ->
+ let extract_decrease decl =
+ let has_decr_clause = has_decreases_clause decl in
+ if has_decr_clause then
+ Extract.extract_template_decreases_clause ctx.extract_ctx fmt decl
+ in
+ extract_decrease fwd;
+ List.iter extract_decrease loop_fwds)
pure_ls;
(* Concatenate the function definitions, filtering the useless forward
@@ -549,8 +557,15 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
let decls =
List.concat
(List.map
- (fun (keep_fwd, (fwd, back_ls)) ->
- if keep_fwd then fwd :: back_ls else back_ls)
+ (fun (keep_fwd, ((fwd, fwd_loops), (back_ls : fun_and_loops list))) ->
+ let fwd = if keep_fwd then List.append fwd_loops [ fwd ] else [] in
+ let back : Pure.fun_decl list =
+ List.concat
+ (List.map
+ (fun (back, loop_backs) -> List.append loop_backs [ back ])
+ back_ls)
+ in
+ List.append fwd back)
pure_ls)
in
@@ -568,7 +583,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Insert unit tests if necessary *)
if config.test_trans_unit_functions then
List.iter
- (fun (keep_fwd, (fwd, _)) ->
+ (fun (keep_fwd, ((fwd, _), _)) ->
if keep_fwd then
Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd)
pure_ls
@@ -721,12 +736,25 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) :
(* We need to compute which functions are recursive, in order to know
* whether we should generate a decrease clause or not. *)
let rec_functions =
- A.FunDeclId.Set.of_list
- (List.concat
- (List.map
- (fun decl -> match decl with A.Fun (Rec ids) -> ids | _ -> [])
- crate.declarations))
+ List.map
+ (fun (_, ((fwd, loop_fwds), _)) ->
+ let fwd =
+ if fwd.Pure.signature.info.effect_info.is_rec then
+ [ (fwd.def_id, None) ]
+ else []
+ in
+ let loop_fwds =
+ List.map
+ (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ])
+ loop_fwds
+ in
+ fwd :: loop_fwds)
+ trans_funs
+ in
+ let rec_functions : PureUtils.fun_loop_id list =
+ List.concat (List.concat rec_functions)
in
+ let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in
(* Register unique names for all the top-level types, globals and functions.
* Note that the order in which we generate the names doesn't matter:
@@ -740,18 +768,21 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) :
let ctx =
List.fold_left
- (fun ctx (keep_fwd, def) ->
+ (fun ctx (keep_fwd, defs) ->
(* We generate a decrease clause for all the recursive functions *)
- let gen_decr_clause =
- A.FunDeclId.Set.mem (fst def).Pure.def_id rec_functions
+ let fwd_def = fst (fst defs) in
+ let gen_decr_clause (def : Pure.fun_decl) =
+ PureUtils.FunLoopIdSet.mem
+ (def.Pure.def_id, def.Pure.loop_id)
+ rec_functions
in
(* Register the names, only if the function is not a global body -
* those are handled later *)
- let is_global = (fst def).Pure.is_global_decl_body in
+ let is_global = fwd_def.Pure.is_global_decl_body in
if is_global then ctx
else
Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause
- def)
+ defs)
ctx trans_funs
in
@@ -785,7 +816,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) :
A.FunDeclId.Map.of_list
(List.map
(fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) ->
- (fd.def_id, (keep_fwd, (fd, bdl))))
+ ((fst fd).def_id, (keep_fwd, (fd, bdl))))
trans_funs)
in
@@ -883,7 +914,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) :
(* Extract the template clauses *)
let needs_clauses_module =
!Config.extract_decreases_clauses
- && not (A.FunDeclId.Set.is_empty rec_functions)
+ && not (PureUtils.FunLoopIdSet.is_empty rec_functions)
in
(if needs_clauses_module && !Config.extract_template_decreases_clauses then
let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in
diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml
index a658147d..9ba73c7e 100644
--- a/compiler/TranslateCore.ml
+++ b/compiler/TranslateCore.ml
@@ -26,7 +26,9 @@ type trans_ctx = {
global_context : global_context;
}
-type pure_fun_translation = Pure.fun_decl * Pure.fun_decl list
+type fun_and_loops = Pure.fun_decl * Pure.fun_decl list
+type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list
+type pure_fun_translation = fun_and_loops * fun_and_loops list
let type_decl_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string =
let type_params = def.type_params in