summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/Extract.ml7
-rw-r--r--compiler/ExtractBase.ml71
-rw-r--r--compiler/PrintPure.ml6
-rw-r--r--compiler/Pure.ml11
-rw-r--r--compiler/PureMicroPasses.ml2
-rw-r--r--compiler/ReorderDecls.ml2
-rw-r--r--compiler/SymbolicToPure.ml7
7 files changed, 69 insertions, 37 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index fd44fab4..fbfcadfd 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -417,11 +417,12 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
let parts = List.map to_snake_case (get_name name) in
String.concat "_" parts
in
- let fun_name (fname : fun_name) (num_rgs : int)
- (rg : region_group_info option) (filter_info : bool * int) : string =
+ let fun_name (fname : fun_name) (num_loops : int) (loop_id : LoopId.id option)
+ (num_rgs : int) (rg : region_group_info option) (filter_info : bool * int)
+ : string =
let fname = fun_name_to_snake_case fname in
(* Compute the suffix *)
- let suffix = default_fun_suffix num_rgs rg filter_info in
+ let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in
(* Concatenate *)
fname ^ suffix
in
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 06c71236..c1ea536a 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -162,13 +162,22 @@ type formatter = {
global_name : global_name -> string;
(** Provided a basename, compute a global name. *)
fun_name :
- fun_name -> int -> region_group_info option -> bool * int -> string;
+ fun_name ->
+ int ->
+ LoopId.id option ->
+ int ->
+ region_group_info option ->
+ bool * int ->
+ string;
(** Compute the name of a regular (non-assumed) function.
Inputs:
- - function id
- function basename (TODO: shouldn't appear for assumed functions?...)
- - number of region groups
+ - number of loops in the function (useful to check if we need to use
+ indices to derive unique names for the loops for instance - if there is
+ exactly one loop, we don't need to use indices)
+ - loop id (if pertinent)
+ - number of region groups (same comment as for the number of loops)
- region group information in case of a backward function
([None] if forward function)
- pair:
@@ -448,7 +457,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
let lp_kind =
match lp_id with
| None -> ""
- | Some lp_id -> "loop " ^ V.LoopId.to_string lp_id ^ ", "
+ | Some lp_id -> "loop " ^ LoopId.to_string lp_id ^ ", "
in
let fwd_back_kind =
@@ -541,7 +550,7 @@ let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string =
let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string =
ctx_get (FunId id) ctx
-let ctx_get_local_function (id : A.FunDeclId.id) (lp : V.LoopId.id option)
+let ctx_get_local_function (id : A.FunDeclId.id) (lp : LoopId.id option)
(rg : RegionGroupId.id option) (ctx : extraction_ctx) : string =
ctx_get_function (FromLlbc (Regular id, lp, rg)) ctx
@@ -704,7 +713,8 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation)
Some { id = rg_id; region_names }
in
let name =
- ctx.fmt.fun_name def.basename num_rgs rg_info (keep_fwd, num_backs)
+ ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info
+ (keep_fwd, num_backs)
in
ctx_add
(FunId (FromLlbc (A.Regular def_id, def.loop_id, def.back_id)))
@@ -788,7 +798,8 @@ let compute_type_decl_name (fmt : formatter) (def : type_decl) : string =
information.
TODO: move all those helpers.
*)
-let default_fun_suffix (num_region_groups : int) (rg : region_group_info option)
+let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option)
+ (num_region_groups : int) (rg : region_group_info option)
((keep_fwd, num_backs) : bool * int) : string =
(* There are several cases:
- [rg] is [Some]: this is a forward function:
@@ -805,21 +816,31 @@ let default_fun_suffix (num_region_groups : int) (rg : region_group_info option)
we could not add the "_fwd" suffix) to prevent name clashes between
definitions (in particular between type and function definitions).
*)
- match rg with
- | None -> "_fwd"
- | Some rg ->
- assert (num_region_groups > 0 && num_backs > 0);
- if num_backs = 1 then
- (* Exactly one backward function *)
- if not keep_fwd then "_fwd_back" else "_back"
- else if
- (* Several region groups/backward functions:
- - if all the regions in the group have names, we use those names
- - otherwise we use an index
- *)
- List.for_all Option.is_some rg.region_names
- then
- (* Concatenate the region names *)
- "_back" ^ String.concat "" (List.map Option.get rg.region_names)
- else (* Use the region index *)
- "_back" ^ RegionGroupId.to_string rg.id
+ let lp_suff =
+ match loop_id with
+ | None -> ""
+ | Some loop_id ->
+ if num_loops = 1 then "_loop" else "_loop" ^ LoopId.to_string loop_id
+ in
+
+ let rg_suff =
+ match rg with
+ | None -> "_fwd"
+ | Some rg ->
+ assert (num_region_groups > 0 && num_backs > 0);
+ if num_backs = 1 then
+ (* Exactly one backward function *)
+ if not keep_fwd then "_fwd_back" else "_back"
+ else if
+ (* Several region groups/backward functions:
+ - if all the regions in the group have names, we use those names
+ - otherwise we use an index
+ *)
+ List.for_all Option.is_some rg.region_names
+ then
+ (* Concatenate the region names *)
+ "_back" ^ String.concat "" (List.map Option.get rg.region_names)
+ else (* Use the region index *)
+ "_back" ^ RegionGroupId.to_string rg.id
+ in
+ lp_suff ^ rg_suff
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 3113347c..152e29c0 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -415,12 +415,12 @@ let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string =
let all_types = List.append inputs [ output ] in
String.concat " -> " all_types
-let fun_suffix (lp_id : V.LoopId.id option) (rg_id : T.RegionGroupId.id option)
- : string =
+let fun_suffix (lp_id : LoopId.id option) (rg_id : T.RegionGroupId.id option) :
+ string =
let lp_suff =
match lp_id with
| None -> ""
- | Some lp_id -> "^loop^" ^ V.LoopId.to_string lp_id
+ | Some lp_id -> "^loop^" ^ LoopId.to_string lp_id
in
let rg_suff =
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 2578273d..9972d539 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -14,6 +14,12 @@ module SymbolicValueId = V.SymbolicValueId
module FunDeclId = A.FunDeclId
module GlobalDeclId = A.GlobalDeclId
+(** We redefine identifiers for loop: in {Values}, the identifiers are global
+ (they monotonically increase across functions) while in {!Pure} we want
+ the indices to start at 0 for every function.
+ *)
+module LoopId = IdGen ()
+
(** We give an identifier to every phase of the synthesis (forward, backward
for group of regions 0, etc.) *)
module SynthPhaseId = IdGen ()
@@ -306,7 +312,7 @@ type pure_assumed_fun_id =
(** A function identifier *)
type fun_id =
- | FromLlbc of A.fun_id * V.LoopId.id option * T.RegionGroupId.id option
+ | FromLlbc of A.fun_id * LoopId.id option * T.RegionGroupId.id option
(** A function coming from LLBC.
The loop id is [None] if the function is actually the auxiliary function
@@ -627,7 +633,8 @@ type fun_body = {
type fun_decl = {
def_id : FunDeclId.id;
- loop_id : V.LoopId.id option;
+ num_loops : int;
+ loop_id : LoopId.id option;
(** [Some] if this definition was generated for a loop *)
back_id : T.RegionGroupId.id option;
basename : fun_name;
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 2e4a534e..a27b9d95 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -621,7 +621,7 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
In this situation, we can remove the call [f@fwd x].
*)
let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : A.fun_id) (lp_id0 : V.LoopId.id option)
+ (id0 : A.fun_id) (lp_id0 : LoopId.id option)
(rg_id0 : T.RegionGroupId.id option) (tys0 : ty list)
(args0 : texpression list) (e : texpression) : bool =
let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list)
diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml
index 4f02eb81..9e15da4e 100644
--- a/compiler/ReorderDecls.ml
+++ b/compiler/ReorderDecls.ml
@@ -8,7 +8,7 @@ let log = Logging.reorder_decls_log
type fun_id = {
def_id : FunDeclId.id;
- lp_id : V.LoopId.id option;
+ lp_id : LoopId.id option;
rg_id : T.RegionGroupId.id option;
}
[@@deriving show, ord]
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index bba3326b..dd662074 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -2028,15 +2028,18 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
Some { inputs; inputs_lvs; body }
in
- (* Note that for now, the loops are still *inside* the function body: we will
- extract them from there later, in {!PureMicroPasses} (by "splitting" the definition).
+ (* Note that for now, the loops are still *inside* the function body (and we
+ haven't counted them): we will extract them from there later, in {!PureMicroPasses}
+ (by "splitting" the definition).
*)
+ let num_loops = 0 in
let loop_id = None in
(* Assemble the declaration *)
let def =
{
def_id;
+ num_loops;
loop_id;
back_id = bid;
basename;