summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2022-12-14 17:24:37 +0100
committerSon HO2023-02-03 11:21:46 +0100
commitb1e57277baf539f1f009f7c927a1a7445ce6ea45 (patch)
tree227f31c267262a0f4a235b8575e37c65af168673 /compiler
parent54a6b5d1a90b7304817175a33fc37444e559b11e (diff)
Add loop ids to the pure functions identifiers
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Extract.ml14
-rw-r--r--compiler/ExtractBase.ml27
-rw-r--r--compiler/PrintPure.ml27
-rw-r--r--compiler/Pure.ml7
-rw-r--r--compiler/PureMicroPasses.ml15
-rw-r--r--compiler/ReorderDecls.ml12
-rw-r--r--compiler/SymbolicToPure.ml11
-rw-r--r--compiler/Translate.ml4
8 files changed, 80 insertions, 37 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index ce0609f5..fd44fab4 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1907,7 +1907,9 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
(kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit =
assert (not def.is_global_decl_body);
(* Retrieve the function name *)
- let def_name = ctx_get_local_function def.def_id def.back_id ctx in
+ let def_name =
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx
+ in
(* Add a break before *)
F.pp_print_break fmt 0 0;
(* Print a comment to link the extracted type to its original rust definition *)
@@ -2141,7 +2143,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
let decl_name = ctx_get_global global.def_id ctx in
let body_name =
- ctx_get_function (FromLlbc (Regular global.body_id, None)) ctx
+ ctx_get_function (FromLlbc (Regular global.body_id, None, None)) ctx
in
let decl_ty, body_ty =
@@ -2211,7 +2213,9 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "assert_norm";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- let fun_name = ctx_get_local_function def.def_id def.back_id ctx in
+ let fun_name =
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx
+ in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
@@ -2225,7 +2229,9 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "Check";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- let fun_name = ctx_get_local_function def.def_id def.back_id ctx in
+ let fun_name =
+ ctx_get_local_function def.def_id def.loop_id def.back_id ctx
+ in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index b1901fca..06c71236 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -436,7 +436,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
"global name: " ^ Print.global_name_to_string name
| FunId fid -> (
match fid with
- | FromLlbc (fid, rg_id) ->
+ | FromLlbc (fid, lp_id, rg_id) ->
let fun_name =
match fid with
| Regular fid ->
@@ -444,12 +444,19 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
(A.FunDeclId.Map.find fid fun_decls).name
| Assumed aid -> A.show_assumed_fun_id aid
in
- let fun_kind =
+
+ let lp_kind =
+ match lp_id with
+ | None -> ""
+ | Some lp_id -> "loop " ^ V.LoopId.to_string lp_id ^ ", "
+ in
+
+ let fwd_back_kind =
match rg_id with
| None -> "forward"
| Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id
in
- "fun name (" ^ fun_kind ^ "): " ^ fun_name
+ "fun name (" ^ lp_kind ^ fwd_back_kind ^ "): " ^ fun_name
| Pure fid -> PrintPure.pure_assumed_fun_id_to_string fid)
| DecreasesClauseId fid ->
let fun_name =
@@ -534,9 +541,9 @@ let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string =
let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string =
ctx_get (FunId id) ctx
-let ctx_get_local_function (id : A.FunDeclId.id) (rg : RegionGroupId.id option)
- (ctx : extraction_ctx) : string =
- ctx_get_function (FromLlbc (Regular id, rg)) ctx
+let ctx_get_local_function (id : A.FunDeclId.id) (lp : V.LoopId.id option)
+ (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string =
+ ctx_get_function (FromLlbc (Regular id, lp, rg)) ctx
let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string =
assert (id <> Tuple);
@@ -662,7 +669,7 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
extraction_ctx =
let name = ctx.fmt.global_name def.name in
let decl = GlobalId def.def_id in
- let body = FunId (FromLlbc (Regular def.body_id, None)) in
+ let body = FunId (FromLlbc (Regular def.body_id, None, None)) in
let ctx = ctx_add decl (name ^ "_c") ctx in
let ctx = ctx_add body (name ^ "_body") ctx in
ctx
@@ -699,7 +706,9 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation)
let name =
ctx.fmt.fun_name def.basename num_rgs rg_info (keep_fwd, num_backs)
in
- ctx_add (FunId (FromLlbc (A.Regular def_id, def.back_id))) name ctx
+ ctx_add
+ (FunId (FromLlbc (A.Regular def_id, def.loop_id, def.back_id)))
+ name ctx
type names_map_init = {
keywords : string list;
@@ -760,7 +769,7 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map =
in
let assumed_functions =
List.map
- (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, rg), name))
+ (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, None, rg), name))
init.assumed_llbc_functions
@ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions
in
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 726cc9a0..3113347c 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -415,10 +415,21 @@ let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string =
let all_types = List.append inputs [ output ] in
String.concat " -> " all_types
-let fun_suffix (rg_id : T.RegionGroupId.id option) : string =
- match rg_id with
- | None -> ""
- | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id
+let fun_suffix (lp_id : V.LoopId.id option) (rg_id : T.RegionGroupId.id option)
+ : string =
+ let lp_suff =
+ match lp_id with
+ | None -> ""
+ | Some lp_id -> "^loop^" ^ V.LoopId.to_string lp_id
+ in
+
+ let rg_suff =
+ match rg_id with
+ | None -> ""
+ | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id
+ in
+
+ lp_suff ^ rg_suff
let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string =
match fid with
@@ -444,13 +455,13 @@ let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string =
let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string =
match fun_id with
- | FromLlbc (fid, rg_id) ->
+ | FromLlbc (fid, lp_id, rg_id) ->
let f =
match fid with
| Regular fid -> fmt.fun_decl_id_to_string fid
| Assumed fid -> llbc_assumed_fun_id_to_string fid
in
- f ^ fun_suffix rg_id
+ f ^ fun_suffix lp_id rg_id
| Pure fid -> pure_assumed_fun_id_to_string fid
let unop_to_string (unop : unop) : string =
@@ -620,7 +631,9 @@ and meta_to_string (fmt : ast_formatter) (meta : meta) : string =
let fun_decl_to_string (fmt : ast_formatter) (def : fun_decl) : string =
let type_fmt = ast_to_type_formatter fmt in
- let name = fun_name_to_string def.basename ^ fun_suffix def.back_id in
+ let name =
+ fun_name_to_string def.basename ^ fun_suffix def.loop_id def.back_id
+ in
let signature = fun_sig_to_string fmt def.signature in
match def.body with
| None -> "val " ^ name ^ " :\n " ^ signature
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index d9d3a404..2578273d 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -306,9 +306,12 @@ type pure_assumed_fun_id =
(** A function identifier *)
type fun_id =
- | FromLlbc of A.fun_id * T.RegionGroupId.id option
+ | FromLlbc of A.fun_id * V.LoopId.id option * T.RegionGroupId.id option
(** A function coming from LLBC.
+ The loop id is [None] if the function is actually the auxiliary function
+ generated from a loop.
+
The region group id is the backward id:: [Some] if the function is a
backward function, [None] if it is a forward function.
*)
@@ -624,6 +627,8 @@ type fun_body = {
type fun_decl = {
def_id : FunDeclId.id;
+ loop_id : V.LoopId.id option;
+ (** [Some] if this definition was generated for a loop *)
back_id : T.RegionGroupId.id option;
basename : fun_name;
(** The "base" name of the function.
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index c5eb3c64..2e4a534e 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -621,15 +621,16 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
In this situation, we can remove the call [f@fwd x].
*)
let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : A.fun_id) (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list)
+ (id0 : A.fun_id) (lp_id0 : V.LoopId.id option)
+ (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list)
(args0 : texpression list) (e : texpression) : bool =
let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list)
(args1 : texpression list) : bool =
(* Check the fun_ids, to see if call1's function is a child of call0's function *)
match fun_id1 with
- | Fun (FromLlbc (id1, rg_id1)) ->
+ | Fun (FromLlbc (id1, lp_id1, rg_id1)) ->
(* Both are "regular" calls: check if they come from the same rust function *)
- if id0 = id1 then
+ if id0 = id1 && lp_id0 = lp_id1 then
(* Same rust functions: check the regions hierarchy *)
let call1_is_child =
match (rg_id0, rg_id1) with
@@ -801,13 +802,13 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
* We can filter if the right-expression is a function call,
* under some conditions. *)
match (filter_monadic_calls, opt_destruct_function_call re) with
- | true, Some (Fun (FromLlbc (fid, rg_id)), tys, args) ->
+ | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) ->
(* We need to check if there is a child call - see
* the comments for:
* [expression_contains_child_call_in_all_paths] *)
let has_child_call =
- expression_contains_child_call_in_all_paths ctx fid rg_id
- tys args e
+ expression_contains_child_call_in_all_paths ctx fid lp_id
+ rg_id tys args e
in
if has_child_call then (* Filter *)
(e.e, fun _ -> used)
@@ -1031,7 +1032,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
match opt_destruct_function_call e with
| Some (fun_id, _tys, args) -> (
match fun_id with
- | Fun (FromLlbc (A.Assumed aid, rg_id)) -> (
+ | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> (
(* Below, when dealing with the arguments: we consider the very
* general case, where functions could be boxed (meaning we
* could have: [box_new f x])
diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml
index 2b78c570..4f02eb81 100644
--- a/compiler/ReorderDecls.ml
+++ b/compiler/ReorderDecls.ml
@@ -6,7 +6,11 @@ open Pure
(** The local logger *)
let log = Logging.reorder_decls_log
-type fun_id = { def_id : FunDeclId.id; rg_id : T.RegionGroupId.id option }
+type fun_id = {
+ def_id : FunDeclId.id;
+ lp_id : V.LoopId.id option;
+ rg_id : T.RegionGroupId.id option;
+}
[@@deriving show, ord]
module FunIdOrderedType : OrderedType with type t = fun_id = struct
@@ -38,11 +42,11 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t =
| FunOrOp (Fun fid) -> (
match fid with
| Pure _ -> ()
- | FromLlbc (fid, rg_id) -> (
+ | FromLlbc (fid, lp_id, rg_id) -> (
match fid with
| Assumed _ -> ()
| Regular fid ->
- let id = { def_id = fid; rg_id } in
+ let id = { def_id = fid; lp_id; rg_id } in
ids := FunIdSet.add id !ids))
end
in
@@ -66,7 +70,7 @@ let group_reorder_fun_decls (decls : fun_decl list) :
(bool * fun_decl list) list =
let module IntMap = MakeMap (OrderedInt) in
let get_fun_id (decl : fun_decl) : fun_id =
- { def_id = decl.def_id; rg_id = decl.back_id }
+ { def_id = decl.def_id; lp_id = decl.loop_id; rg_id = decl.back_id }
in
(* Compute the list/set of identifiers *)
let idl = List.map get_fun_id decls in
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 8a6f82f9..bba3326b 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -269,7 +269,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
(* Retrieve the fun_id *)
let fun_id =
match info.forward.call_id with
- | S.Fun (fid, _) -> Fun (FromLlbc (fid, Some back_id))
+ | S.Fun (fid, _) -> Fun (FromLlbc (fid, None, Some back_id))
| S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
in
(* Update the context and return *)
@@ -1268,7 +1268,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
- let func = Fun (FromLlbc (fid, None)) in
+ let func = Fun (FromLlbc (fid, None, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
@@ -2027,10 +2027,17 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(List.combine inputs signature.inputs));
Some { inputs; inputs_lvs; body }
in
+
+ (* Note that for now, the loops are still *inside* the function body: we will
+ extract them from there later, in {!PureMicroPasses} (by "splitting" the definition).
+ *)
+ let loop_id = None in
+
(* Assemble the declaration *)
let def =
{
def_id;
+ loop_id;
back_id = bid;
basename;
signature;
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 87862d6b..75aeb37c 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -459,9 +459,7 @@ let export_functions_declarations (fmt : Format.formatter) (config : gen_config)
in
(* Extract the function declarations *)
- (* Check if the functions are mutually recursive - this really works
- * to check if the forward and backward translations of a single
- * recursive function are mutually recursive *)
+ (* Check if the functions are mutually recursive *)
let is_mut_rec = List.length decls > 1 in
assert ((not is_mut_rec) || is_rec);
let decls_length = List.length decls in