summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/Config.ml107
-rw-r--r--compiler/Contexts.ml2
-rw-r--r--compiler/Extract.ml444
-rw-r--r--compiler/ExtractBase.ml243
-rw-r--r--compiler/ExtractBuiltin.ml138
-rw-r--r--compiler/ExtractTypes.ml2
-rw-r--r--compiler/Main.ml9
-rw-r--r--compiler/PrintPure.ml18
-rw-r--r--compiler/Pure.ml4
-rw-r--r--compiler/PureMicroPasses.ml331
-rw-r--r--compiler/ReorderDecls.ml12
-rw-r--r--compiler/SymbolicToPure.ml949
-rw-r--r--compiler/Translate.ml145
-rw-r--r--compiler/TranslateCore.ml15
14 files changed, 696 insertions, 1723 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index 3b0070c0..af0e62d1 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -92,69 +92,6 @@ let loop_fixed_point_max_num_iters = 2
(** {1 Translation} *)
-(** If true, do not define separate forward/backward functions, but make the
- forward functions return the backward function.
-
- Example:
- {[
- (* Rust *)
- pub fn list_nth<'a, T>(l: &'a mut List<T>, i: u32) -> &'a mut T {
- match l {
- List::Nil => {
- panic!()
- }
- List::Cons(x, tl) => {
- if i == 0 {
- x
- } else {
- list_nth(tl, i - 1)
- }
- }
- }
- }
-
- (* Translation, if return_back_funs = false *)
- def list_nth (T : Type) (l : List T) (i : U32) : Result T :=
- match l with
- | List.Cons x tl =>
- if i = 0#u32
- then Result.ret x
- else do
- let i0 ← i - 1#u32
- list_nth T tl i0
- | List.Nil => Result.fail .panic
-
- def list_nth_back
- (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) :=
- match l with
- | List.Cons x tl =>
- if i = 0#u32
- then Result.ret (List.Cons ret tl)
- else
- do
- let i0 ← i - 1#u32
- let tl0 ← list_nth_back T tl i0 ret
- Result.ret (List.Cons x tl0)
- | List.Nil => Result.fail .panic
-
- (* Translation, if return_back_funs = true *)
- def list_nth (T: Type) (ls : List T) (i : U32) :
- Result (T × (T → Result (List T))) :=
- match ls with
- | List.Cons x tl =>
- if i = 0#u32
- then Result.ret (x, (λ ret => return (ret :: ls)))
- else do
- let i0 ← i - 1#u32
- let (x, back) ← list_nth ls i0
- Return.ret (x,
- (λ ret => do
- let ls ← back ret
- return (x :: ls)))
- ]}
- *)
-let return_back_funs = ref true
-
(** Forbids using field projectors for structures.
If we don't use field projectors, whenever we symbolically expand a structure
@@ -326,50 +263,6 @@ let decompose_nested_let_patterns = ref false
*)
let unfold_monadic_let_bindings = ref false
-(** Controls whether we try to filter the calls to monadic functions
- (which can fail) when their outputs are not used.
-
- The useless calls are calls to backward functions which have no outputs.
- This case happens if the original Rust function only takes *shared* borrows
- as inputs, and is thus pretty common.
-
- We are allowed to do this only because in this specific case,
- the backward function fails *exactly* when the forward function fails
- (they actually do exactly the same thing, the only difference being
- that the forward function can potentially return a value), and upon
- reaching the place where we should introduce a call to the backward
- function, we know we have introduced a call to the forward function.
-
- Also note that in general, backward functions "do more things" than
- forward functions, and have more opportunities to fail (even though
- in the generated code, calls to the backward functions should fail
- exactly when the corresponding, previous call to the forward functions
- failed).
-
- This optimization is done in {!SymbolicToPure}. We might want to move it to
- the micro-passes subsequent to the translation from symbolic to pure, but it
- is really super easy to do it when going from symbolic to pure. Note that
- we later filter the useless *forward* calls in the micro-passes, where it is
- more natural to do.
-
- See the comments for {!PureMicroPasses.expression_contains_child_call_in_all_paths}
- for additional explanations.
- *)
-let filter_useless_monadic_calls = ref true
-
-(** If {!filter_useless_monadic_calls} is activated, some functions
- become useless: if this option is true, we don't extract them.
-
- The calls to functions which always get filtered are:
- - the forward functions with unit return value
- - the backward functions which don't output anything (backward
- functions coming from rust functions with no mutable borrows
- as input values - note that if a function doesn't take mutable
- borrows as inputs, it can't return mutable borrows; we actually
- dynamically check for that).
- *)
-let filter_useless_functions = ref true
-
(** Simplify the forward/backward functions, in case we merge them
(i.e., the forward functions return the backward functions).
diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml
index b1dd9553..54411fd5 100644
--- a/compiler/Contexts.ml
+++ b/compiler/Contexts.ml
@@ -109,6 +109,8 @@ let reset_global_counters () =
region_id_counter := RegionId.generator_zero;
abstraction_id_counter := AbstractionId.generator_zero;
loop_id_counter := LoopId.generator_zero;
+ (* We want the loop id to start at 1 *)
+ let _ = fresh_loop_id () in
fun_call_id_counter := FunCallId.generator_zero;
dummy_var_id_counter := DummyVarId.generator_zero
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index dbca4f8f..794a1bfa 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -9,8 +9,7 @@ open TranslateCore
open Config
include ExtractTypes
-(** Compute the names for all the pure functions generated from a rust function
- (forward function and backward functions).
+(** Compute the names for all the pure functions generated from a rust function.
*)
let extract_fun_decl_register_names (ctx : extraction_ctx)
(has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) :
@@ -19,63 +18,36 @@ let extract_fun_decl_register_names (ctx : extraction_ctx)
method implementations): we do not need to refer to them directly. We will
only use their type for the fields of the records we generate for the trait
declarations *)
- match def.fwd.f.kind with
+ match def.f.kind with
| TraitMethodDecl _ -> ctx
| _ -> (
(* Check if the function is builtin *)
let builtin =
let open ExtractBuiltin in
let funs_map = builtin_funs_map () in
- match_name_find_opt ctx.trans_ctx def.fwd.f.llbc_name funs_map
+ match_name_find_opt ctx.trans_ctx def.f.llbc_name funs_map
in
(* Use the builtin names if necessary *)
match builtin with
- | Some (filter_info, info) ->
- (* Register the filtering information, if there is *)
+ | Some (filter_info, fun_info) ->
+ (* Builtin function: register the filtering information, if there is *)
let ctx =
match filter_info with
| Some keep ->
{
ctx with
funs_filter_type_args_map =
- FunDeclId.Map.add def.fwd.f.def_id keep
+ FunDeclId.Map.add def.f.def_id keep
ctx.funs_filter_type_args_map;
}
| _ -> ctx
in
- let funs =
- if !Config.return_back_funs then [ def.fwd.f ]
- else
- let backs = List.map (fun f -> f.f) def.backs in
- if def.keep_fwd then def.fwd.f :: backs else backs
- in
- List.fold_left
- (fun ctx (f : fun_decl) ->
- let open ExtractBuiltin in
- let fun_id =
- (Pure.FunId (FRegular f.def_id), f.loop_id, f.back_id)
- in
- let fun_info =
- List.find_opt
- (fun (x : builtin_fun_info) -> x.rg = f.back_id)
- info
- in
- match fun_info with
- | Some fun_info ->
- ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx
- | None ->
- raise
- (Failure
- ("Not found: "
- ^ name_to_string ctx f.llbc_name
- ^ ", "
- ^ Print.option_to_string Pure.show_loop_id f.loop_id
- ^ Print.option_to_string Pure.show_region_group_id
- f.back_id)))
- ctx funs
+ let f = def.f in
+ let open ExtractBuiltin in
+ let fun_id = (Pure.FunId (FRegular f.def_id), f.loop_id) in
+ ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx
| None ->
- let fwd = def.fwd in
- let backs = def.backs in
+ (* Not builtin *)
(* Register the decrease clauses, if necessary *)
let register_decreases ctx def =
if has_decreases_clause def then
@@ -88,21 +60,15 @@ let extract_fun_decl_register_names (ctx : extraction_ctx)
| Lean -> ctx_add_decreases_proof def ctx
else ctx
in
- let ctx =
- List.fold_left register_decreases ctx (fwd.f :: fwd.loops)
- in
- let register_fun ctx f = ctx_add_fun_decl def f ctx in
+ (* We have to register the function itself, and the loops it
+ may contain (which are extracted as functions) *)
+ let funs = def.f :: def.loops in
+ (* Register the decrease clauses *)
+ let ctx = List.fold_left register_decreases ctx funs in
+ (* Register the name of the function and the loops *)
+ let register_fun ctx f = ctx_add_fun_decl f ctx in
let register_funs ctx fl = List.fold_left register_fun ctx fl in
- (* Register the names of the forward functions *)
- let ctx =
- if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx
- in
- (* Register the names of the backward functions *)
- List.fold_left
- (fun ctx { f = back; loops = loop_backs } ->
- let ctx = register_fun ctx back in
- register_funs ctx loop_backs)
- ctx backs)
+ register_funs ctx funs)
(** Simply add the global name to the context. *)
let extract_global_decl_register_names (ctx : extraction_ctx)
@@ -230,7 +196,7 @@ let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list)
let decl = FunDeclId.Map.find id ctx.trans_funs in
let err =
"Ill-formed builtin information for function "
- ^ name_to_string ctx decl.fwd.f.llbc_name
+ ^ name_to_string ctx decl.f.llbc_name
^ ": "
^ string_of_int (List.length filter)
^ " filtering arguments provided for "
@@ -460,8 +426,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
]}
*)
(match fun_id with
- | FromLlbc
- (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id, rg_id) ->
+ | FromLlbc (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id) ->
(* We have to check whether the trait method is required or provided *)
let trait_decl_id = trait_ref.trait_decl_ref.trait_decl_id in
let trait_decl =
@@ -477,7 +442,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref;
let fun_name =
ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id
- method_name rg_id ctx
+ method_name ctx
in
let add_brackets (s : string) =
if !backend = Coq then "(" ^ s ^ ")" else s
@@ -486,9 +451,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
else
(* Provided method: we see it as a regular function call, and use
the function name *)
- let fun_id =
- FromLlbc (FunId (FRegular method_id.id), lp_id, rg_id)
- in
+ let fun_id = FromLlbc (FunId (FRegular method_id.id), lp_id) in
let fun_name = ctx_get_function fun_id ctx in
F.pp_print_string fmt fun_name;
@@ -513,7 +476,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
*)
let types =
match fun_id with
- | FromLlbc (FunId (FRegular id), _, _) ->
+ | FromLlbc (FunId (FRegular id), _) ->
fun_builtin_filter_types id generics.types ctx
| _ -> Result.Ok generics.types
in
@@ -1392,11 +1355,6 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx)
let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
- let { keep_fwd; num_backs } =
- PureUtils.RegularFunIdMap.find
- (Pure.FunId (FRegular def.def_id), def.loop_id, def.back_id)
- ctx.fun_name_info
- in
let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]:" in
let comment =
let loop_comment =
@@ -1404,23 +1362,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
| None -> ""
| Some id -> " loop " ^ LoopId.to_string id ^ ":"
in
- let fwd_back_comment =
- match def.back_id with
- | None -> if !Config.return_back_funs then [] else [ "forward function" ]
- | Some id ->
- (* Check if there is only one backward function, and no forward function *)
- if (not keep_fwd) && num_backs = 1 then
- [
- "merged forward/backward function";
- "(there is a single backward function, and the forward function \
- returns ())";
- ]
- else [ "backward function " ^ T.RegionGroupId.to_string id ]
- in
- match fwd_back_comment with
- | [] -> [ comment_pre ^ loop_comment ]
- | [ s ] -> [ comment_pre ^ loop_comment ^ " " ^ s ]
- | s :: sl -> (comment_pre ^ loop_comment ^ " " ^ s) :: sl
+ [ comment_pre ^ loop_comment ]
in
extract_comment_with_span fmt comment def.meta.span
@@ -1435,9 +1377,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit =
assert (not def.is_global_decl_body);
(* Retrieve the function name *)
- let def_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let def_name = ctx_get_local_function def.def_id def.loop_id ctx in
(* Add a break before *)
if !backend <> HOL4 || not (decl_is_first_from_group kind) then
F.pp_print_break fmt 0 0;
@@ -1681,9 +1621,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
(* Retrieve the definition name *)
- let def_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let def_name = ctx_get_local_function def.def_id def.loop_id ctx in
assert (def.signature.generics.const_generics = []);
(* Add the type/const gen parameters - note that we need those bindings
only for the generation of the type (they are not top-level) *)
@@ -1870,7 +1808,6 @@ let extract_global_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(global : A.global_decl) (body : fun_decl) (interface : bool) : unit =
assert body.is_global_decl_body;
- assert (Option.is_none body.back_id);
assert (body.signature.inputs = []);
assert (body.signature.generics = empty_generic_params);
@@ -1883,9 +1820,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
let decl_name = ctx_get_global global.def_id ctx in
let body_name =
- ctx_get_function
- (FromLlbc (Pure.FunId (FRegular global.body), None, None))
- ctx
+ ctx_get_function (FromLlbc (Pure.FunId (FRegular global.body), None)) ctx
in
let decl_ty, body_ty =
@@ -2058,80 +1993,45 @@ let extract_trait_decl_method_names (ctx : extraction_ctx)
let required_methods = trait_decl.required_methods in
(* Compute the names *)
let method_names =
- (* We add one field per required forward/backward function *)
- let get_funs_for_id (id : fun_decl_id) : fun_decl list =
- let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in
- if !Config.return_back_funs then [ trans.fwd.f ]
- else List.map (fun f -> f.f) (trans.fwd :: trans.backs)
- in
match builtin_info with
| None ->
- (* We add one field per required forward/backward function *)
- let compute_item_names (item_name : string) (id : fun_decl_id) :
- string * (RegionGroupId.id option * string) list =
- let compute_fun_name (f : fun_decl) : RegionGroupId.id option * string
- =
- (* We do something special to reuse the [ctx_compute_fun_decl]
- function. TODO: make it cleaner. *)
- let llbc_name : Types.name =
- [ Types.PeIdent (item_name, Disambiguator.zero) ]
- in
- let f = { f with llbc_name } in
- let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in
- let name = ctx_compute_fun_name trans f ctx in
- (* Add a prefix if necessary *)
- let name =
- if !Config.record_fields_short_names then name
- else ctx_compute_trait_decl_name ctx trait_decl ^ "_" ^ name
- in
- (f.back_id, name)
+ (* Not a builtin function *)
+ let compute_item_name (item_name : string) (id : fun_decl_id) :
+ string * string =
+ let trans : pure_fun_translation =
+ FunDeclId.Map.find id ctx.trans_funs
+ in
+ let f = trans.f in
+ (* We do something special to reuse the [ctx_compute_fun_decl]
+ function. TODO: make it cleaner. *)
+ let llbc_name : Types.name =
+ [ Types.PeIdent (item_name, Disambiguator.zero) ]
+ in
+ let f = { f with llbc_name } in
+ let name = ctx_compute_fun_name f ctx in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx_compute_trait_decl_name ctx trait_decl ^ "_" ^ name
in
- let funs = get_funs_for_id id in
- (item_name, List.map compute_fun_name funs)
+ (item_name, name)
in
- List.map (fun (name, id) -> compute_item_names name id) required_methods
+ List.map (fun (name, id) -> compute_item_name name id) required_methods
| Some info ->
+ (* This is a builtin *)
let funs_map = StringMap.of_list info.methods in
List.map
- (fun (item_name, fun_id) ->
+ (fun (item_name, _) ->
let open ExtractBuiltin in
let info = StringMap.find item_name funs_map in
- let trans_funs = get_funs_for_id fun_id in
- let find (trans_fun : fun_decl) =
- let info =
- List.find_opt
- (fun (info : builtin_fun_info) -> info.rg = trans_fun.back_id)
- info
- in
- match info with
- | Some info -> (info.rg, info.extract_name)
- | None ->
- let err =
- "Ill-formed builtin information for trait decl \""
- ^ name_to_string ctx trait_decl.llbc_name
- ^ "\", method \"" ^ item_name
- ^ "\": could not find name for region "
- ^ Print.option_to_string Pure.show_region_group_id
- trans_fun.back_id
- in
- log#serror err;
- if !Config.fail_hard then raise (Failure err)
- else (trans_fun.back_id, "%ERROR_BUILTIN_NAME_NOT_FOUND%")
- in
- let rg_with_name_list = List.map find trans_funs in
- (item_name, rg_with_name_list))
+ let fun_name = info.extract_name in
+ (item_name, fun_name))
required_methods
in
(* Register the names *)
List.fold_left
- (fun ctx (item_name, funs) ->
- (* We add one field per required forward/backward function *)
- List.fold_left
- (fun ctx (rg, fun_name) ->
- ctx_add
- (TraitMethodId (trait_decl.def_id, item_name, rg))
- fun_name ctx)
- ctx funs)
+ (fun ctx (item_name, fun_name) ->
+ ctx_add (TraitMethodId (trait_decl.def_id, item_name)) fun_name ctx)
ctx method_names
(** Similar to {!extract_type_decl_register_names} *)
@@ -2263,46 +2163,41 @@ let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
(* Lookup the definition *)
let trans = A.FunDeclId.Map.find id ctx.trans_funs in
(* Extract the items *)
- let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
- let extract_method (f : fun_and_loops) =
- let f = f.f in
- let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in
- let ty () =
- (* Extract the generics *)
- (* We need to add the generics specific to the method, by removing those
- which actually apply to the trait decl *)
- let generics =
- let drop_trait_clauses = false in
- generic_params_drop_prefix ~drop_trait_clauses decl.generics
- f.signature.generics
- in
- (* Note that we do not filter the LLBC generic parameters.
- This is ok because:
- - we only use them to find meaningful names for the trait clauses
- - we only generate trait clauses for the clauses we find in the
- pure generics *)
- let ctx, type_params, cg_params, trait_clauses =
- ctx_add_generic_params f.llbc_name f.signature.llbc_generics generics
- ctx
- in
- let backend_uses_forall =
- match !backend with Coq | Lean -> true | FStar | HOL4 -> false
- in
- let generics_not_empty = generics <> empty_generic_params in
- let use_forall = generics_not_empty && backend_uses_forall in
- let use_arrows = generics_not_empty && not backend_uses_forall in
- let use_forall_use_sep = false in
- extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall
- ~use_forall_use_sep ~use_arrows generics type_params cg_params
- trait_clauses;
- if use_forall then F.pp_print_string fmt ",";
- (* Extract the inputs and output *)
- F.pp_print_space fmt ();
- extract_fun_inputs_output_parameters_types ctx fmt f
+ let f = trans.f in
+ let fun_name = ctx_get_trait_method decl.def_id item_name ctx in
+ let ty () =
+ (* Extract the generics *)
+ (* We need to add the generics specific to the method, by removing those
+ which actually apply to the trait decl *)
+ let generics =
+ let drop_trait_clauses = false in
+ generic_params_drop_prefix ~drop_trait_clauses decl.generics
+ f.signature.generics
+ in
+ (* Note that we do not filter the LLBC generic parameters.
+ This is ok because:
+ - we only use them to find meaningful names for the trait clauses
+ - we only generate trait clauses for the clauses we find in the
+ pure generics *)
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params f.llbc_name f.signature.llbc_generics generics ctx
in
- extract_trait_decl_item ctx fmt fun_name ty
+ let backend_uses_forall =
+ match !backend with Coq | Lean -> true | FStar | HOL4 -> false
+ in
+ let generics_not_empty = generics <> empty_generic_params in
+ let use_forall = generics_not_empty && backend_uses_forall in
+ let use_arrows = generics_not_empty && not backend_uses_forall in
+ let use_forall_use_sep = false in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall
+ ~use_forall_use_sep ~use_arrows generics type_params cg_params
+ trait_clauses;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the inputs and output *)
+ F.pp_print_space fmt ();
+ extract_fun_inputs_output_parameters_types ctx fmt f
in
- List.iter extract_method funs
+ extract_trait_decl_item ctx fmt fun_name ty
(** Extract a trait declaration *)
let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter)
@@ -2494,21 +2389,10 @@ let extract_trait_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
decl.parent_clauses;
(* The required methods *)
List.iter
- (fun (item_name, id) ->
- (* Lookup the definition *)
- let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+ (fun (item_name, _) ->
(* Extract the items *)
- let funs =
- if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs
- in
- let extract_for_method (f : fun_and_loops) =
- let f = f.f in
- let item_name =
- ctx_get_trait_method decl.def_id item_name f.back_id ctx
- in
- extract_coq_arguments_instruction ctx fmt item_name num_params
- in
- List.iter extract_for_method funs)
+ let item_name = ctx_get_trait_method decl.def_id item_name ctx in
+ extract_coq_arguments_instruction ctx fmt item_name num_params)
decl.required_methods;
(* Add a space *)
F.pp_print_space fmt ())
@@ -2531,75 +2415,71 @@ let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
(* Lookup the definition *)
let trans = A.FunDeclId.Map.find id ctx.trans_funs in
(* Extract the items *)
- let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
- let extract_method (f : fun_and_loops) =
- let f = f.f in
- let fun_name = ctx_get_trait_method trait_decl_id item_name f.back_id ctx in
- let ty () =
- (* Filter the generics if the method is a builtin *)
- let i_tys, _, _ = impl_generics in
- let impl_types, i_tys, f_tys =
- match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with
- | None -> (impl.generics.types, i_tys, f.signature.generics.types)
- | Some filter ->
- let filter_list filter ls =
- let ls = List.combine filter ls in
- List.filter_map (fun (b, ty) -> if b then Some ty else None) ls
- in
- let impl_types = impl.generics.types in
- let impl_filter =
- Collections.List.prefix (List.length impl_types) filter
- in
- let i_tys = i_tys in
- let i_filter = Collections.List.prefix (List.length i_tys) filter in
- ( filter_list impl_filter impl_types,
- filter_list i_filter i_tys,
- filter_list filter f.signature.generics.types )
- in
- let f_generics = { f.signature.generics with types = f_tys } in
- (* Extract the generics - we need to quantify over the generics which
- are specific to the method, and call it will all the generics
- (trait impl + method generics) *)
- let f_generics =
- let drop_trait_clauses = true in
- generic_params_drop_prefix ~drop_trait_clauses
- { impl.generics with types = impl_types }
- f_generics
- in
- (* Register and print the quantified generics.
-
- Note that we do not filter the LLBC generic parameters.
- This is ok because:
- - we only use them to find meaningful names for the trait clauses
- - we only generate trait clauses for the clauses we find in the
- pure generics *)
- let ctx, f_tys, f_cgs, f_tcs =
- ctx_add_generic_params f.llbc_name f.signature.llbc_generics f_generics
- ctx
- in
- let use_forall = f_generics <> empty_generic_params in
- extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics
- f_tys f_cgs f_tcs;
- if use_forall then F.pp_print_string fmt ",";
- (* Extract the function call *)
- F.pp_print_space fmt ();
- let fun_name = ctx_get_local_function f.def_id None f.back_id ctx in
- F.pp_print_string fmt fun_name;
- let all_generics =
- let _, i_cgs, i_tcs = impl_generics in
- List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ]
- in
-
- (* Filter the generics if the function is builtin *)
- List.iter
- (fun p ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt p)
- all_generics
+ let f = trans.f in
+ let fun_name = ctx_get_trait_method trait_decl_id item_name ctx in
+ let ty () =
+ (* Filter the generics if the method is a builtin *)
+ let i_tys, _, _ = impl_generics in
+ let impl_types, i_tys, f_tys =
+ match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with
+ | None -> (impl.generics.types, i_tys, f.signature.generics.types)
+ | Some filter ->
+ let filter_list filter ls =
+ let ls = List.combine filter ls in
+ List.filter_map (fun (b, ty) -> if b then Some ty else None) ls
+ in
+ let impl_types = impl.generics.types in
+ let impl_filter =
+ Collections.List.prefix (List.length impl_types) filter
+ in
+ let i_tys = i_tys in
+ let i_filter = Collections.List.prefix (List.length i_tys) filter in
+ ( filter_list impl_filter impl_types,
+ filter_list i_filter i_tys,
+ filter_list filter f.signature.generics.types )
+ in
+ let f_generics = { f.signature.generics with types = f_tys } in
+ (* Extract the generics - we need to quantify over the generics which
+ are specific to the method, and call it will all the generics
+ (trait impl + method generics) *)
+ let f_generics =
+ let drop_trait_clauses = true in
+ generic_params_drop_prefix ~drop_trait_clauses
+ { impl.generics with types = impl_types }
+ f_generics
+ in
+ (* Register and print the quantified generics.
+
+ Note that we do not filter the LLBC generic parameters.
+ This is ok because:
+ - we only use them to find meaningful names for the trait clauses
+ - we only generate trait clauses for the clauses we find in the
+ pure generics *)
+ let ctx, f_tys, f_cgs, f_tcs =
+ ctx_add_generic_params f.llbc_name f.signature.llbc_generics f_generics
+ ctx
+ in
+ let use_forall = f_generics <> empty_generic_params in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics
+ f_tys f_cgs f_tcs;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the function call *)
+ F.pp_print_space fmt ();
+ let fun_name = ctx_get_local_function f.def_id None ctx in
+ F.pp_print_string fmt fun_name;
+ let all_generics =
+ let _, i_cgs, i_tcs = impl_generics in
+ List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ]
in
- extract_trait_impl_item ctx fmt fun_name ty
+
+ (* Filter the generics if the function is builtin *)
+ List.iter
+ (fun p ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt p)
+ all_generics
in
- List.iter extract_method funs
+ extract_trait_impl_item ctx fmt fun_name ty
(** Extract a trait implementation *)
let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
@@ -2766,8 +2646,6 @@ let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
*)
let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
- (* We only insert unit tests for forward functions *)
- assert (def.back_id = None);
(* Check if this is a unit function *)
let sg = def.signature in
if
@@ -2791,9 +2669,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "assert_norm";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
@@ -2807,9 +2683,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "Check";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
@@ -2820,9 +2694,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "#assert";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
@@ -2835,9 +2707,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
| HOL4 ->
F.pp_print_string fmt "val _ = assert_return (";
F.pp_print_string fmt "“";
- let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 5aa8323e..591e8aab 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -167,11 +167,7 @@ type id =
| TraitImplId of TraitImplId.id
| LocalTraitClauseId of TraitClauseId.id
| TraitDeclConstructorId of TraitDeclId.id
- | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option
- (** Something peculiar with trait methods: because we have to take into
- account forward/backward functions, we may need to generate fields
- items per method.
- *)
+ | TraitMethodId of TraitDeclId.id * string
| TraitItemId of TraitDeclId.id * string
(** A trait associated item which is not a method *)
| TraitParentClauseId of TraitDeclId.id * TraitClauseId.id
@@ -353,8 +349,6 @@ let basename_to_unique (names_set : StringSet.t)
in
if StringSet.mem basename names_set then gen 1 else basename
-type fun_name_info = { keep_fwd : bool; num_backs : int }
-
type names_maps = {
names_map : names_map;
(** The map for id to names, where we forbid name collisions
@@ -384,7 +378,7 @@ let allow_collisions (id : id) : bool =
| FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitItemId _
| TraitMethodId _ ->
!Config.record_fields_short_names
- | FunId (Pure _ | FromLlbc (FunId (FAssumed _), _, _)) ->
+ | FunId (Pure _ | FromLlbc (FunId (FAssumed _), _)) ->
(* We map several assumed functions to the same id *)
true
| _ -> false
@@ -471,8 +465,7 @@ type names_map_init = {
assumed_adts : (assumed_ty * string) list;
assumed_structs : (assumed_ty * string) list;
assumed_variants : (assumed_ty * VariantId.id * string) list;
- assumed_llbc_functions :
- (A.assumed_fun_id * RegionGroupId.id option * string) list;
+ assumed_llbc_functions : (A.assumed_fun_id * string) list;
assumed_pure_functions : (pure_assumed_fun_id * string) list;
}
@@ -550,15 +543,6 @@ type extraction_ctx = {
-- makes the if then else dependent
]}
*)
- fun_name_info : fun_name_info PureUtils.RegularFunIdMap.t;
- (** Information used to filter and name functions - we use it
- to print comments in the generated code, to help link
- the generated code to the original code (information such
- as: "this function is the backward function of ...", or
- "this function is the merged forward/backward function of ..."
- in case a Rust function only has one backward translation
- and we filter the forward function because it returns unit.
- *)
trait_decl_id : trait_decl_id option;
(** If we are extracting a trait declaration, identifies it *)
is_provided_method : bool;
@@ -669,14 +653,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
^ TraitClauseId.to_string clause_id
| TraitItemId (id, name) ->
"trait_item_id: " ^ trait_decl_id_to_string id ^ ", type name: " ^ name
- | TraitMethodId (trait_decl_id, fun_name, rg_id) ->
- let fwd_back_kind =
- match rg_id with
- | None -> "forward"
- | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id
- in
- trait_decl_id_to_string trait_decl_id
- ^ ", method name (" ^ fwd_back_kind ^ "): " ^ fun_name
+ | TraitMethodId (trait_decl_id, fun_name) ->
+ trait_decl_id_to_string trait_decl_id ^ ", method name: " ^ fun_name
| TraitSelfClauseId -> "trait_self_clause"
let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx =
@@ -695,8 +673,8 @@ let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string =
ctx_get (FunId id) ctx
let ctx_get_local_function (id : A.FunDeclId.id) (lp : LoopId.id option)
- (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string =
- ctx_get_function (FromLlbc (FunId (FRegular id), lp, rg)) ctx
+ (ctx : extraction_ctx) : string =
+ ctx_get_function (FromLlbc (FunId (FRegular id), lp)) ctx
let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string =
assert (id <> TTuple);
@@ -734,8 +712,8 @@ let ctx_get_trait_type (id : trait_decl_id) (item_name : string)
ctx_get_trait_item id item_name ctx
let ctx_get_trait_method (id : trait_decl_id) (item_name : string)
- (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string =
- ctx_get (TraitMethodId (id, item_name, rg_id)) ctx
+ (ctx : extraction_ctx) : string =
+ ctx_get (TraitMethodId (id, item_name)) ctx
let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id)
(ctx : extraction_ctx) : string =
@@ -1052,63 +1030,28 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
(* No Fuel::Succ on purpose *)
]
-let assumed_llbc_functions () :
- (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
- let rg0 = Some T.RegionGroupId.zero in
- let regular : (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
- match !backend with
- | FStar | Coq | HOL4 ->
- [
- (ArrayIndexShared, None, "array_index_usize");
- (ArrayToSliceShared, None, "array_to_slice");
- (ArrayRepeat, None, "array_repeat");
- (SliceIndexShared, None, "slice_index_usize");
- ]
- | Lean ->
- [
- (ArrayIndexShared, None, "Array.index_usize");
- (ArrayToSliceShared, None, "Array.to_slice");
- (ArrayRepeat, None, "Array.repeat");
- (SliceIndexShared, None, "Slice.index_usize");
- ]
- in
- let mut_funs : (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
- if !Config.return_back_funs then
- match !backend with
- | FStar | Coq | HOL4 ->
- [
- (ArrayIndexMut, None, "array_index_mut_usize");
- (ArrayToSliceMut, None, "array_to_slice_mut");
- (SliceIndexMut, None, "slice_index_mut_usize");
- ]
- | Lean ->
- [
- (ArrayIndexMut, None, "Array.index_mut_usize");
- (ArrayToSliceMut, None, "Array.to_slice_mut");
- (SliceIndexMut, None, "Slice.index_mut_usize");
- ]
- else
- match !backend with
- | FStar | Coq | HOL4 ->
- [
- (ArrayIndexMut, None, "array_index_usize");
- (ArrayIndexMut, rg0, "array_update_usize");
- (ArrayToSliceMut, None, "array_to_slice");
- (ArrayToSliceMut, rg0, "array_from_slice");
- (SliceIndexMut, None, "slice_index_usize");
- (SliceIndexMut, rg0, "slice_update_usize");
- ]
- | Lean ->
- [
- (ArrayIndexMut, None, "Array.index_usize");
- (ArrayIndexMut, rg0, "Array.update_usize");
- (ArrayToSliceMut, None, "Array.to_slice");
- (ArrayToSliceMut, rg0, "Array.from_slice");
- (SliceIndexMut, None, "Slice.index_usize");
- (SliceIndexMut, rg0, "Slice.update_usize");
- ]
- in
- regular @ mut_funs
+let assumed_llbc_functions () : (A.assumed_fun_id * string) list =
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ [
+ (ArrayIndexShared, "array_index_usize");
+ (ArrayIndexMut, "array_index_mut_usize");
+ (ArrayToSliceShared, "array_to_slice");
+ (ArrayToSliceMut, "array_to_slice_mut");
+ (ArrayRepeat, "array_repeat");
+ (SliceIndexShared, "slice_index_usize");
+ (SliceIndexMut, "slice_index_mut_usize");
+ ]
+ | Lean ->
+ [
+ (ArrayIndexShared, "Array.index_usize");
+ (ArrayIndexMut, "Array.index_mut_usize");
+ (ArrayToSliceShared, "Array.to_slice");
+ (ArrayToSliceMut, "Array.to_slice_mut");
+ (ArrayRepeat, "Array.repeat");
+ (SliceIndexShared, "Slice.index_usize");
+ (SliceIndexMut, "Slice.index_mut_usize");
+ ]
let assumed_pure_functions () : (pure_assumed_fun_id * string) list =
match !backend with
@@ -1200,8 +1143,7 @@ let initialize_names_maps () : names_maps =
in
let assumed_functions =
List.map
- (fun (fid, rg, name) ->
- (FromLlbc (Pure.FunId (FAssumed fid), None, rg), name))
+ (fun (fid, name) -> (FromLlbc (Pure.FunId (FAssumed fid), None), name))
init.assumed_llbc_functions
@ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions
in
@@ -1444,61 +1386,12 @@ let default_fun_loop_suffix (num_loops : int) (loop_id : LoopId.id option) :
If this function admits only one loop, we omit it. *)
if num_loops = 1 then "_loop" else "_loop" ^ LoopId.to_string loop_id
-(** A helper function: generates a function suffix from a region group
- information.
+(** A helper function: generates a function suffix.
TODO: move all those helpers.
*)
-let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option)
- (num_region_groups : int) (rg : region_group_info option)
- ((keep_fwd, num_backs) : bool * int) : string =
- let lp_suff = default_fun_loop_suffix num_loops loop_id in
-
- (* There are several cases:
- - [rg] is [Some]: this is a forward function:
- - we add "_fwd"
- - [rg] is [None]: this is a backward function:
- - this function has one extracted backward function:
- - if the forward function has been filtered, we add nothing:
- the forward function is useless, so the unique backward function
- takes its place, in a way (in effect, we "merge" the forward
- and the backward functions).
- - otherwise we add "_back"
- - this function has several backward functions: we add "_back" and an
- additional suffix to identify the precise backward function
- Note that we always add a suffix (in case there are no region groups,
- we could not add the "_fwd" suffix) to prevent name clashes between
- definitions (in particular between type and function definitions).
- *)
- let rg_suff =
- (* TODO: make all the backends match what is done for Lean *)
- match rg with
- | None ->
- if
- (* In order to avoid name conflicts:
- * - if the forward is eliminated, we add the suffix "_fwd" (it won't be used)
- * - otherwise, no suffix (because the backward functions will have a suffix)
- *)
- num_backs = 1 && not keep_fwd
- then "_fwd"
- else ""
- | Some rg ->
- assert (num_region_groups > 0 && num_backs > 0);
- if num_backs = 1 then
- (* Exactly one backward function *)
- if not keep_fwd then "" else "_back"
- else if
- (* Several region groups/backward functions:
- - if all the regions in the group have names, we use those names
- - otherwise we use an index
- *)
- List.for_all Option.is_some rg.region_names
- then
- (* Concatenate the region names *)
- "_back" ^ String.concat "" (List.map Option.get rg.region_names)
- else (* Use the region index *)
- "_back" ^ RegionGroupId.to_string rg.id
- in
- lp_suff ^ rg_suff
+let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) : string =
+ (* We only generate a suffix for the functions we generate from the loops *)
+ default_fun_loop_suffix num_loops loop_id
(** Compute the name of a regular (non-assumed) function.
@@ -1508,24 +1401,13 @@ let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option)
indices to derive unique names for the loops for instance - if there is
exactly one loop, we don't need to use indices)
- loop id (if pertinent)
- - number of region groups
- - region group information in case of a backward function
- ([None] if forward function)
- - pair:
- - do we generate the forward function (it may have been filtered)?
- - the number of *extracted backward functions* (same comment as for
- the number of loops)
- The number of extracted backward functions if not necessarily
- equal to the number of region groups, because we may have
- filtered some of them.
TODO: use the fun id for the assumed functions.
*)
let ctx_compute_fun_name (ctx : extraction_ctx) (fname : llbc_name)
- (num_loops : int) (loop_id : LoopId.id option) (num_rgs : int)
- (rg : region_group_info option) (filter_info : bool * int) : string =
+ (num_loops : int) (loop_id : LoopId.id option) : string =
let fname = ctx_compute_fun_name_no_suffix ctx fname in
(* Compute the suffix *)
- let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in
+ let suffix = default_fun_suffix num_loops loop_id in
(* Concatenate *)
fname ^ suffix
@@ -1999,61 +1881,26 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
| None ->
(* Not the case: "standard" registration *)
let name = ctx_compute_global_name ctx def.name in
- let body = FunId (FromLlbc (FunId (FRegular def.body), None, None)) in
+ let body = FunId (FromLlbc (FunId (FRegular def.body), None)) in
let ctx = ctx_add decl (name ^ "_c") ctx in
let ctx = ctx_add body (name ^ "_body") ctx in
ctx
-let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl)
- (ctx : extraction_ctx) : string =
- (* Lookup the LLBC def to compute the region group information *)
- let def_id = def.def_id in
- let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in
- let sg = llbc_def.signature in
- let regions_hierarchy =
- LlbcAstUtils.FunIdMap.find (FRegular def_id)
- ctx.trans_ctx.fun_ctx.regions_hierarchies
- in
- let num_rgs = List.length regions_hierarchy in
- let { keep_fwd; fwd = _; backs } = trans_group in
- let num_backs = List.length backs in
- let rg_info =
- match def.back_id with
- | None -> None
- | Some rg_id ->
- let rg = T.RegionGroupId.nth regions_hierarchy rg_id in
- let region_names =
- List.map
- (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
- rg.regions
- in
- Some { id = rg_id; region_names }
- in
+let ctx_compute_fun_name (def : fun_decl) (ctx : extraction_ctx) : string =
(* Add the function name *)
- ctx_compute_fun_name ctx def.llbc_name def.num_loops def.loop_id num_rgs
- rg_info (keep_fwd, num_backs)
+ ctx_compute_fun_name ctx def.llbc_name def.num_loops def.loop_id
(* TODO: move to Extract *)
-let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl)
- (ctx : extraction_ctx) : extraction_ctx =
+let ctx_add_fun_decl (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx =
(* Sanity check: the function should not be a global body - those are handled
* separately *)
assert (not def.is_global_decl_body);
(* Lookup the LLBC def to compute the region group information *)
let def_id = def.def_id in
- let { keep_fwd; fwd = _; backs } = trans_group in
- let num_backs = List.length backs in
(* Add the function name *)
- let def_name = ctx_compute_fun_name trans_group def ctx in
- let fun_id = (Pure.FunId (FRegular def_id), def.loop_id, def.back_id) in
- let ctx = ctx_add (FunId (FromLlbc fun_id)) def_name ctx in
- (* Add the name info *)
- {
- ctx with
- fun_name_info =
- PureUtils.RegularFunIdMap.add fun_id { keep_fwd; num_backs }
- ctx.fun_name_info;
- }
+ let def_name = ctx_compute_fun_name def ctx in
+ let fun_id = (Pure.FunId (FRegular def_id), def.loop_id) in
+ ctx_add (FunId (FromLlbc fun_id)) def_name ctx
let ctx_compute_type_decl_name (ctx : extraction_ctx) (def : type_decl) : string
=
diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml
index ee8d4831..88de31fe 100644
--- a/compiler/ExtractBuiltin.ml
+++ b/compiler/ExtractBuiltin.ml
@@ -213,11 +213,7 @@ let mk_builtin_types_map () =
let builtin_types_map = mk_memoized mk_builtin_types_map
-type builtin_fun_info = {
- rg : Types.RegionGroupId.id option;
- extract_name : string;
-}
-[@@deriving show]
+type builtin_fun_info = { extract_name : string } [@@deriving show]
(** The assumed functions.
@@ -225,21 +221,11 @@ type builtin_fun_info = {
parameters. For instance, in the case of the `Vec` functions, there is
a type parameter for the allocator to use, which we want to filter.
*)
-let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list
- =
- let rg0 = Some Types.RegionGroupId.zero in
+let builtin_funs () : (pattern * bool list option * builtin_fun_info) list =
(* Small utility *)
let mk_fun (rust_name : string) (extract_name : string option)
- (filter : bool list option) (with_back : bool) (back_no_suffix : bool) :
- pattern * bool list option * builtin_fun_info list =
- (* [back_no_suffix] is used to control whether the backward function should
- have the suffix "_back" or not (if not, then the forward function has the
- prefix "_fwd", and is filtered anyway). This is pertinent only if we split
- the fwd/back functions. *)
- let back_no_suffix = back_no_suffix && not !Config.return_back_funs in
- (* Same for the [with_back] option: this is pertinent only if we split
- the fwd/back functions *)
- let with_back = with_back && not !Config.return_back_funs in
+ (filter : bool list option) :
+ pattern * bool list option * builtin_fun_info =
let rust_name =
try parse_pattern rust_name
with Failure _ ->
@@ -251,68 +237,51 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list
| Some name -> split_on_separator name
in
let basename = flatten_name extract_name in
- let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in
- let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in
- let back_suffix = if with_back && back_no_suffix then "" else "_back" in
- let back =
- if with_back then [ { rg = rg0; extract_name = basename ^ back_suffix } ]
- else []
- in
- (rust_name, filter, fwd @ back)
+ let f = { extract_name = basename } in
+ (rust_name, filter, f)
in
[
- mk_fun "core::mem::replace" None None true false;
+ mk_fun "core::mem::replace" None None;
mk_fun "core::slice::{[@T]}::len"
(Some (backend_choice "slice::len" "Slice::len"))
- None true false;
+ None;
mk_fun "alloc::vec::{alloc::vec::Vec<@T, alloc::alloc::Global>}::new"
- (Some "alloc::vec::Vec::new") None false false;
+ (Some "alloc::vec::Vec::new") None;
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::push" None
- (Some [ true; false ])
- true true;
+ (Some [ true; false ]);
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::insert" None
- (Some [ true; false ])
- true true;
+ (Some [ true; false ]);
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::len" None
- (Some [ true; false ])
- true false;
+ (Some [ true; false ]);
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::index" None
- (Some [ true; true; false ])
- true false;
+ (Some [ true; true; false ]);
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::index_mut" None
- (Some [ true; true; false ])
- true false;
- mk_fun "alloc::boxed::{Box<@T>}::deref" None
- (Some [ true; false ])
- true false;
- mk_fun "alloc::boxed::{Box<@T>}::deref_mut" None
- (Some [ true; false ])
- true false;
- mk_fun "core::slice::index::{[@T]}::index" None None true false;
- mk_fun "core::slice::index::{[@T]}::index_mut" None None true false;
- mk_fun "core::array::{[@T; @C]}::index" None None true false;
- mk_fun "core::array::{[@T; @C]}::index_mut" None None true false;
+ (Some [ true; true; false ]);
+ mk_fun "alloc::boxed::{Box<@T>}::deref" None (Some [ true; false ]);
+ mk_fun "alloc::boxed::{Box<@T>}::deref_mut" None (Some [ true; false ]);
+ mk_fun "core::slice::index::{[@T]}::index" None None;
+ mk_fun "core::slice::index::{[@T]}::index_mut" None None;
+ mk_fun "core::array::{[@T; @C]}::index" None None;
+ mk_fun "core::array::{[@T; @C]}::index_mut" None None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get"
- (Some "core::slice::index::RangeUsize::get") None true false;
+ (Some "core::slice::index::RangeUsize::get") None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_mut"
- (Some "core::slice::index::RangeUsize::get_mut") None true false;
+ (Some "core::slice::index::RangeUsize::get_mut") None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::index"
- (Some "core::slice::index::RangeUsize::index") None true false;
+ (Some "core::slice::index::RangeUsize::index") None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::index_mut"
- (Some "core::slice::index::RangeUsize::index_mut") None true false;
+ (Some "core::slice::index::RangeUsize::index_mut") None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_unchecked"
- (Some "core::slice::index::RangeUsize::get_unchecked") None false false;
+ (Some "core::slice::index::RangeUsize::get_unchecked") None;
mk_fun
"core::slice::index::{core::ops::range::Range<usize>}::get_unchecked_mut"
- (Some "core::slice::index::RangeUsize::get_unchecked_mut") None false
- false;
- mk_fun "core::slice::index::{usize}::get" None None true false;
- mk_fun "core::slice::index::{usize}::get_mut" None None true false;
- mk_fun "core::slice::index::{usize}::get_unchecked" None None false false;
- mk_fun "core::slice::index::{usize}::get_unchecked_mut" None None false
- false;
- mk_fun "core::slice::index::{usize}::index" None None true false;
- mk_fun "core::slice::index::{usize}::index_mut" None None true false;
+ (Some "core::slice::index::RangeUsize::get_unchecked_mut") None;
+ mk_fun "core::slice::index::{usize}::get" None None;
+ mk_fun "core::slice::index::{usize}::get_mut" None None;
+ mk_fun "core::slice::index::{usize}::get_unchecked" None None;
+ mk_fun "core::slice::index::{usize}::get_unchecked_mut" None None;
+ mk_fun "core::slice::index::{usize}::index" None None;
+ mk_fun "core::slice::index::{usize}::index_mut" None None;
]
let mk_builtin_funs_map () =
@@ -407,15 +376,14 @@ type builtin_trait_decl_info = {
- a Rust name
- an extraction name
- a list of clauses *)
- methods : (string * builtin_fun_info list) list;
+ methods : (string * builtin_fun_info) list;
}
[@@deriving show]
let builtin_trait_decls_info () =
- let rg0 = Some Types.RegionGroupId.zero in
let mk_trait (rust_name : string) ?(extract_name : string option = None)
?(parent_clauses : string list = []) ?(types : string list = [])
- ?(methods : (string * bool) list = []) () : builtin_trait_decl_info =
+ ?(methods : string list = []) () : builtin_trait_decl_info =
let rust_name = parse_pattern rust_name in
let extract_name =
match extract_name with
@@ -443,22 +411,14 @@ let builtin_trait_decls_info () =
List.map mk_type types
in
let methods =
- let mk_method (item_name, with_back) =
+ let mk_method item_name =
(* TODO: factor out with builtin_funs_info *)
let basename =
if !record_fields_short_names then item_name
else extract_name ^ "_" ^ item_name
in
- let back_no_suffix = false in
- let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in
- let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in
- let back_suffix = if with_back && back_no_suffix then "" else "_back" in
- let back =
- if with_back then
- [ { rg = rg0; extract_name = basename ^ back_suffix } ]
- else []
- in
- (item_name, fwd @ back)
+ let fwd = { extract_name = basename } in
+ (item_name, fwd)
in
List.map mk_method methods
in
@@ -474,21 +434,17 @@ let builtin_trait_decls_info () =
in
[
(* Deref *)
- mk_trait "core::ops::deref::Deref" ~types:[ "Target" ]
- ~methods:[ ("deref", true) ]
+ mk_trait "core::ops::deref::Deref" ~types:[ "Target" ] ~methods:[ "deref" ]
();
(* DerefMut *)
mk_trait "core::ops::deref::DerefMut" ~parent_clauses:[ "derefInst" ]
- ~methods:[ ("deref_mut", true) ]
- ();
+ ~methods:[ "deref_mut" ] ();
(* Index *)
- mk_trait "core::ops::index::Index" ~types:[ "Output" ]
- ~methods:[ ("index", true) ]
+ mk_trait "core::ops::index::Index" ~types:[ "Output" ] ~methods:[ "index" ]
();
(* IndexMut *)
mk_trait "core::ops::index::IndexMut" ~parent_clauses:[ "indexInst" ]
- ~methods:[ ("index_mut", true) ]
- ();
+ ~methods:[ "index_mut" ] ();
(* Sealed *)
mk_trait "core::slice::index::private_slice_index::Sealed" ();
(* SliceIndex *)
@@ -496,12 +452,12 @@ let builtin_trait_decls_info () =
~types:[ "Output" ]
~methods:
[
- ("get", true);
- ("get_mut", true);
- ("get_unchecked", false);
- ("get_unchecked_mut", false);
- ("index", true);
- ("index_mut", true);
+ "get";
+ "get_mut";
+ "get_unchecked";
+ "get_unchecked_mut";
+ "index";
+ "index_mut";
]
();
]
diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml
index a3dbf3cc..05b71b9f 100644
--- a/compiler/ExtractTypes.ml
+++ b/compiler/ExtractTypes.ml
@@ -272,7 +272,7 @@ let start_fun_decl_group (ctx : extraction_ctx) (fmt : F.formatter)
if is_single_opaque_fun_decl_group dg then ()
else
let compute_fun_def_name (def : Pure.fun_decl) : string =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx ^ "_def"
+ ctx_get_local_function def.def_id def.loop_id ctx ^ "_def"
in
let names = List.map compute_fun_def_name dg in
(* Add a break before *)
diff --git a/compiler/Main.ml b/compiler/Main.ml
index 4a2d01dc..3f5e62ad 100644
--- a/compiler/Main.ml
+++ b/compiler/Main.ml
@@ -72,12 +72,6 @@ let () =
Arg.Symbol (backend_names, set_backend),
" Specify the target backend" );
("-dest", Arg.Set_string dest_dir, " Specify the output directory");
- ( "-no-filter-useless-calls",
- Arg.Clear filter_useless_monadic_calls,
- " Do not filter the useless function calls" );
- ( "-no-filter-useless-funs",
- Arg.Clear filter_useless_functions,
- " Do not filter the useless forward/backward functions" );
( "-test-units",
Arg.Set test_unit_functions,
" Test the unit functions with the concrete (i.e., not symbolic) \
@@ -120,9 +114,6 @@ let () =
" Generate a default lakefile.lean (Lean only)" );
("-print-llbc", Arg.Set print_llbc, " Print the imported LLBC");
("-k", Arg.Clear fail_hard, " Do not fail hard in case of error");
- ( "-split-fwd-back",
- Arg.Clear return_back_funs,
- " Split the forward and backward functions." );
( "-tuple-nested-proj",
Arg.Set use_nested_tuple_projectors,
" Use nested projectors for tuples (e.g., (0, 1).snd.fst instead of \
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 66475d02..21ca7f08 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -462,21 +462,13 @@ let inst_fun_sig_to_string (env : fmt_env) (sg : inst_fun_sig) : string =
let all_types = List.append inputs [ output ] in
String.concat " -> " all_types
-let fun_suffix (lp_id : LoopId.id option) (rg_id : T.RegionGroupId.id option) :
- string =
+let fun_suffix (lp_id : LoopId.id option) : string =
let lp_suff =
match lp_id with
| None -> ""
| Some lp_id -> "^loop" ^ LoopId.to_string lp_id
in
-
- let rg_suff =
- match rg_id with
- | None -> ""
- | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id
- in
-
- lp_suff ^ rg_suff
+ lp_suff
let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string =
match fid with
@@ -505,7 +497,7 @@ let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string =
let regular_fun_id_to_string (env : fmt_env) (fun_id : fun_id) : string =
match fun_id with
- | FromLlbc (fid, lp_id, rg_id) ->
+ | FromLlbc (fid, lp_id) ->
let f =
match fid with
| FunId (FRegular fid) -> fun_decl_id_to_string env fid
@@ -513,7 +505,7 @@ let regular_fun_id_to_string (env : fmt_env) (fun_id : fun_id) : string =
| TraitMethod (trait_ref, method_name, _) ->
trait_ref_to_string env true trait_ref ^ "." ^ method_name
in
- f ^ fun_suffix lp_id rg_id
+ f ^ fun_suffix lp_id
| Pure fid -> pure_assumed_fun_id_to_string fid
let unop_to_string (unop : unop) : string =
@@ -746,7 +738,7 @@ and emeta_to_string (env : fmt_env) (meta : emeta) : string =
let fun_decl_to_string (env : fmt_env) (def : fun_decl) : string =
let env = { env with generics = def.signature.generics } in
- let name = def.name ^ fun_suffix def.loop_id def.back_id in
+ let name = def.name ^ fun_suffix def.loop_id in
let signature = fun_sig_to_string env def.signature in
match def.body with
| None -> "val " ^ name ^ " :\n " ^ signature
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index a879ba37..dd7a4acf 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -560,8 +560,7 @@ type fun_id_or_trait_method_ref =
[@@deriving show, ord]
(** A function id for a non-assumed function *)
-type regular_fun_id =
- fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option
+type regular_fun_id = fun_id_or_trait_method_ref * LoopId.id option
[@@deriving show, ord]
(** A function identifier *)
@@ -1078,7 +1077,6 @@ type fun_decl = {
*)
loop_id : LoopId.id option;
(** [Some] if this definition was generated for a loop *)
- back_id : RegionGroupId.id option;
llbc_name : llbc_name; (** The original LLBC name. *)
name : string;
(** We use the name only for printing purposes (for debugging):
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index ec64df21..04bc90d7 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -925,156 +925,9 @@ let inline_useless_var_reassignments (ctx : trans_ctx) ~(inline_named : bool)
in
{ def with body = Some body }
-(** For the cases where we split the forward/backward functions.
-
- Given a forward or backward function call, is there, for every execution
- path, a child backward function called later with exactly the same input
- list prefix. We use this to filter useless function calls: if there are
- such child calls, we can remove this one (in case its outputs are not
- used).
- We do this check because we can't simply remove function calls whose
- outputs are not used, as they might fail. However, if a function fails,
- its children backward functions then fail on the same inputs (ignoring
- the additional inputs those receive).
-
- For instance, if we have:
- {[
- fn f<'a>(x : &'a mut T);
- ]}
-
- We often have things like this in the synthesized code:
- {[
- _ <-- f@fwd x;
- ...
- nx <-- f@back'a x y;
- ...
- ]}
-
- If [f@back'a x y] fails, then necessarily [f@fwd x] also fails.
- In this situation, we can remove the call [f@fwd x].
- *)
-let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option)
- (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args)
- (args0 : texpression list) (e : texpression) : bool =
- let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args)
- (args1 : texpression list) : bool =
- (* Check the fun_ids, to see if call1's function is a child of call0's function *)
- match fun_id1 with
- | Fun (FromLlbc (id1, lp_id1, rg_id1)) ->
- (* Both are "regular" calls: check if they come from the same rust function *)
- if id0 = id1 && lp_id0 = lp_id1 then
- (* Same rust functions: check the regions hierarchy *)
- let call1_is_child =
- match (rg_id0, rg_id1) with
- | None, _ ->
- (* The function used in call0 is the forward function: the one
- * used in call1 is necessarily a child *)
- true
- | Some _, None ->
- (* Opposite of previous case *)
- false
- | Some rg_id0, Some rg_id1 ->
- if rg_id0 = rg_id1 then true
- else
- (* We need to use the regions hierarchy *)
- let regions_hierarchy =
- let id0 =
- match id0 with
- | FunId fun_id -> fun_id
- | TraitMethod (_, _, fun_decl_id) -> FRegular fun_decl_id
- in
- LlbcAstUtils.FunIdMap.find id0
- ctx.fun_ctx.regions_hierarchies
- in
- (* Compute the set of ancestors of the function in call1 *)
- let call1_ancestors =
- LlbcAstUtils.list_ancestor_region_groups regions_hierarchy
- rg_id1
- in
- (* Check if the function used in call0 is inside *)
- T.RegionGroupId.Set.mem rg_id0 call1_ancestors
- in
- (* If call1 is a child, then we need to check if the input arguments
- * used in call0 are a prefix of the input arguments used in call1
- * (note call1 being a child, it will likely consume strictly more
- * given back values).
- * *)
- if call1_is_child then
- let call1_args =
- Collections.List.prefix (List.length args0) args1
- in
- let args = List.combine args0 call1_args in
- (* Note that the input values are expressions, *which may contain
- * meta-values* (which we need to ignore). *)
- let input_eq (v0, v1) =
- PureUtils.remove_meta v0 = PureUtils.remove_meta v1
- in
- (* Compare the generics and the prefix of the input arguments *)
- generics0 = generics1 && List.for_all input_eq args
- else (* Not a child *)
- false
- else (* Not the same function *)
- false
- | _ -> false
- in
-
- let visitor =
- object (self)
- inherit [_] reduce_expression
- method zero _ = false
- method plus b0 b1 _ = b0 () && b1 ()
-
- method! visit_texpression env e =
- match e.e with
- | Var _ | CVar _ | Const _ -> fun _ -> false
- | StructUpdate _ ->
- (* There shouldn't be monadic calls in structure updates - also
- note that by returning [false] we are conservative: we might
- *prevent* possible optimisations (i.e., filtering some function
- calls), which is sound. *)
- fun _ -> false
- | Let (_, _, re, e) -> (
- match opt_destruct_function_call re with
- | None -> fun () -> self#visit_texpression env e ()
- | Some (func1, generics1, args1) ->
- let call_is_child = check_call func1 generics1 args1 in
- if call_is_child then fun () -> true
- else fun () -> self#visit_texpression env e ())
- | Lambda (_, e) -> self#visit_texpression env e
- | App _ -> (
- fun () ->
- match opt_destruct_function_call e with
- | Some (func1, tys1, args1) -> check_call func1 tys1 args1
- | None -> false)
- | Qualif _ ->
- (* Note that this case includes functions without arguments *)
- fun () -> false
- | Meta (_, e) -> self#visit_texpression env e
- | Loop loop ->
- (* We only visit the *function end* *)
- self#visit_texpression env loop.fun_end
- | Switch (_, body) -> self#visit_switch_body env body
-
- method! visit_switch_body env body =
- match body with
- | If (e1, e2) ->
- fun () ->
- self#visit_texpression env e1 ()
- && self#visit_texpression env e2 ()
- | Match branches ->
- fun () ->
- List.for_all
- (fun br -> self#visit_texpression env br.branch ())
- branches
- end
- in
- visitor#visit_texpression () e ()
-
(** Filter the useless assignments (removes the useless variables, filters
the function calls) *)
-let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
- (def : fun_decl) : fun_decl =
+let filter_useless (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* We first need a transformation on *left-values*, which filters the useless
* variables and tells us whether the value contains any variable which has
* not been replaced by [_] (in which case we need to keep the assignment,
@@ -1166,30 +1019,8 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
if not monadic then
(* Not a monadic let-binding: simple case *)
(e.e, fun _ -> used)
- else
- (* Monadic let-binding: trickier.
- * We can filter if the right-expression is a function call,
- * under some conditions. *)
- match (filter_monadic_calls, opt_destruct_function_call re) with
- | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) ->
- (* If we split the forward/backward functions.
-
- We need to check if there is a child call - see
- the comments for:
- [expression_contains_child_call_in_all_paths] *)
- if not !Config.return_back_funs then
- let has_child_call =
- expression_contains_child_call_in_all_paths ctx fid
- lp_id rg_id tys args e
- in
- if has_child_call then (* Filter *)
- (e.e, fun _ -> used)
- else (* No child call: don't filter *)
- dont_filter ()
- else dont_filter ()
- | _ ->
- (* Not an LLBC function call or not allowed to filter: we can't filter *)
- dont_filter ()
+ else (* Monadic let-binding: can't filter *)
+ dont_filter ()
else (* There are used variables: don't filter *)
dont_filter ()
| Loop loop ->
@@ -1442,22 +1273,6 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let body = { body with body = body_exp } in
{ def with body = Some body }
-(** Return [None] if the function is a backward function with no outputs (so
- that we eliminate the definition which is useless).
-
- Note that the calls to such functions are filtered when translating from
- symbolic to pure. Here, we remove the definitions altogether, because they
- are now useless
- *)
-let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
- if
- !Config.filter_useless_functions
- && Option.is_some def.back_id
- && def.signature.output = mk_result_ty mk_unit_ty
- || def.signature.output = mk_unit_ty
- then None
- else Some def
-
(** Retrieve the loop definitions from the function definition.
{!SymbolicToPure} generates an AST in which the loop bodies are part of
@@ -1530,14 +1345,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
info.num_inputs_with_fuel_no_state
info.num_inputs_with_fuel_with_state
in
- let back_inputs =
- if !Config.return_back_funs then []
- else
- snd
- (Collections.List.split_at fun_sig.inputs
- info.num_inputs_with_fuel_with_state)
- in
- List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ]
+ List.concat [ fuel; fwd_inputs; fwd_state ]
in
let output = loop.output_ty in
@@ -1618,7 +1426,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
kind = def.kind;
num_loops;
loop_id = Some loop.loop_id;
- back_id = def.back_id;
llbc_name = def.llbc_name;
name = def.name;
signature = loop_sig;
@@ -1640,35 +1447,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
let loops = List.map snd (LoopId.Map.bindings !loops) in
(def, loops)
-(** Return [false] if the forward function is useless and should be filtered.
-
- - a forward function with no output (comes from a Rust function with
- unit return type)
- - the function has mutable borrows as inputs (which is materialized
- by the fact we generated backward functions which were not filtered).
-
- In such situation, every call to the Rust function will be translated to:
- - a call to the forward function which returns nothing
- - calls to the backward functions
- As a failing backward function implies the forward function also fails,
- we can filter the calls to the forward function, which thus becomes
- useless.
- In such situation, we can remove the forward function definition
- altogether.
- *)
-let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
- (* The question of filtering the forward functions arises only if we split
- the forward/backward functions *)
- if !Config.return_back_funs then true
- else if
- (* Note that at this point, the output types are no longer seen as tuples:
- * they should be lists of length 1. *)
- !Config.filter_useless_functions
- && fwd.f.signature.output = mk_result_ty mk_unit_ty
- && backs <> []
- then false
- else true
-
(** Convert the unit variables to [()] if they are used as right-values or
[_] if they are used as left values in patterns. *)
let unit_vars_to_unit (def : fun_decl) : fun_decl =
@@ -1724,19 +1502,17 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
* could have: [box_new f x])
* *)
match fun_id with
- | Fun (FromLlbc (FunId (FAssumed aid), _lp_id, rg_id)) -> (
- match (aid, rg_id) with
- | BoxNew, _ ->
- assert (rg_id = None);
+ | Fun (FromLlbc (FunId (FAssumed aid), _lp_id)) -> (
+ match aid with
+ | BoxNew ->
let arg, args = Collections.List.pop args in
mk_apps arg args
- | BoxFree, _ ->
+ | BoxFree ->
assert (args = []);
mk_unit_rvalue
- | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared
- | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
- | ArrayRepeat ),
- _ ) ->
+ | SliceIndexShared | SliceIndexMut | ArrayIndexShared
+ | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
+ | ArrayRepeat ->
super#visit_texpression env e)
| _ -> super#visit_texpression env e)
| _ -> super#visit_texpression env e
@@ -1989,7 +1765,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
(* Filter the useless variables, assignments, function calls, etc. *)
- let def = filter_useless !Config.filter_useless_monadic_calls ctx def in
+ let def = filter_useless ctx def in
log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
(* Simplify the lets immediately followed by a return.
@@ -2130,16 +1906,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
*)
let all_decls =
List.concat
- (List.concat
- (List.concat
- (List.map
- (fun { fwd; backs; _ } ->
- [ fwd.f :: fwd.loops ]
- :: List.map
- (fun { f = back; loops = loops_back } ->
- [ back :: loops_back ])
- backs)
- transl)))
+ (List.concat (List.map (fun { f; loops } -> [ f :: loops ]) transl))
in
let subgroups = ReorderDecls.group_reorder_fun_decls all_decls in
@@ -2207,7 +1974,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) ->
+ | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id'))) ->
if (fun_id', loop_id') = fun_id then (
(* For each argument, check if it is exactly the original
input parameter. Note that there shouldn't be partial
@@ -2357,8 +2124,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _)))
- -> (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id))) -> (
match
FunLoopIdMap.find_opt (fun_id, loop_id) !used_map
with
@@ -2400,13 +2166,8 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
in
let transl =
List.map
- (fun trans ->
- let filter_fun_and_loops f =
- { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }
- in
- let fwd = filter_fun_and_loops trans.fwd in
- let backs = List.map filter_fun_and_loops trans.backs in
- { trans with fwd; backs })
+ (fun f ->
+ { f = filter_in_one f.f; loops = List.map filter_in_one f.loops })
transl
in
@@ -2420,18 +2181,11 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
it thus returns the pair: (function def, loop defs). See {!decompose_loops}
for more information.
- Will return [None] if the function is a backward function with no outputs.
-
[ctx]: used only for printing.
*)
-let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
- fun_and_loops option =
+let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_and_loops =
(* Debug *)
- log#ldebug
- (lazy
- ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string def.back_id
- ^ ")"));
+ log#ldebug (lazy ("PureMicroPasses.apply_passes_to_def: " ^ def.name));
log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
@@ -2451,29 +2205,13 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
let def = remove_meta def in
log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
- (* Remove the backward functions with no outputs.
+ (* Extract the loop definitions by removing the {!Loop} node *)
+ let def, loops = decompose_loops ctx def in
- Note that the *calls* to those functions should already have been removed,
- when translating from symbolic to pure. Here, we remove the definitions
- altogether, because they are now useless *)
- let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in
- let opt_def = filter_if_backward_with_no_outputs def in
-
- match opt_def with
- | None ->
- log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n"));
- None
- | Some def ->
- log#ldebug
- (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n"));
-
- (* Extract the loop definitions by removing the {!Loop} node *)
- let def, loops = decompose_loops ctx def in
-
- (* Apply the remaining passes *)
- let f = apply_end_passes_to_def ctx def in
- let loops = List.map (apply_end_passes_to_def ctx) loops in
- Some { f; loops }
+ (* Apply the remaining passes *)
+ let f = apply_end_passes_to_def ctx def in
+ let loops = List.map (apply_end_passes_to_def ctx) loops in
+ { f; loops }
(** Apply the micro-passes to a list of forward/backward translations.
@@ -2489,18 +2227,11 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
but convenient.
*)
let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
- (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list =
- let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation =
- (* Apply the passes to the individual functions *)
- let fwd, backs = trans in
- let fwd = Option.get (apply_passes_to_def ctx fwd) in
- let backs = List.filter_map (apply_passes_to_def ctx) backs in
- (* Compute whether we need to filter the forward function or not *)
- let keep_fwd = keep_forward fwd backs in
- { keep_fwd; fwd; backs }
- in
-
- let transl = List.map apply_to_one transl in
+ (transl : fun_decl list) : pure_fun_translation list =
+ (* Apply the micro-passes *)
+ let transl = List.map (apply_passes_to_def ctx) transl in
- (* Filter the useless inputs in the loop functions *)
+ (* Filter the useless inputs in the loop functions (loops are initially
+ parameterized by *all* the symbolic values in the context, because
+ they may access any of them). *)
filter_loop_inputs transl
diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml
index 53c94ff4..f5443e03 100644
--- a/compiler/ReorderDecls.ml
+++ b/compiler/ReorderDecls.ml
@@ -5,11 +5,7 @@ open Pure
(** The local logger *)
let log = Logging.reorder_decls_log
-type fun_id = {
- def_id : FunDeclId.id;
- lp_id : LoopId.id option;
- rg_id : T.RegionGroupId.id option;
-}
+type fun_id = { def_id : FunDeclId.id; lp_id : LoopId.id option }
[@@deriving show, ord]
module FunIdOrderedType : OrderedType with type t = fun_id = struct
@@ -43,11 +39,11 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t =
| FunOrOp (Fun fid) -> (
match fid with
| Pure _ -> ()
- | FromLlbc (fid, lp_id, rg_id) -> (
+ | FromLlbc (fid, lp_id) -> (
match fid with
| FunId (FAssumed _) -> ()
| TraitMethod (_, _, fid) | FunId (FRegular fid) ->
- let id = { def_id = fid; lp_id; rg_id } in
+ let id = { def_id = fid; lp_id } in
ids := FunIdSet.add id !ids))
end
in
@@ -71,7 +67,7 @@ let group_reorder_fun_decls (decls : fun_decl list) :
(bool * fun_decl list) list =
let module IntMap = MakeMap (OrderedInt) in
let get_fun_id (decl : fun_decl) : fun_id =
- { def_id = decl.def_id; lp_id = decl.loop_id; rg_id = decl.back_id }
+ { def_id = decl.def_id; lp_id = decl.loop_id }
in
(* Compute the list/set of identifiers *)
let idl = List.map get_fun_id decls in
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3a50e495..2db5f66c 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -805,11 +805,9 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
that we need to call. This function may be [None] if it has to be ignored
(because it does nothing).
*)
-let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
- (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id)
- (inherited_args : texpression list) (back_args : texpression list)
- (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) :
- bs_ctx * texpression option =
+let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
+ (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
+ : bs_ctx * texpression option =
(* Insert the abstraction in the call informations *)
let info = V.FunCallId.Map.find call_id ctx.calls in
let calls = V.FunCallId.Map.add call_id info ctx.calls in
@@ -819,29 +817,9 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
let abstractions =
V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
in
- (* Compute the expression corresponding to the function *)
- let func =
- if !Config.return_back_funs then
- (* Lookup the variable introduced for the backward function *)
- RegionGroupId.Map.find back_id (Option.get info.back_funs)
- else
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) ->
- let fid = translate_fun_id_or_trait_method_ref ctx fid in
- Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
- in
- let args = List.append inherited_args back_args in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output_ty else output_ty
- in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp fun_id; generics } in
- Some { e = Qualif func; ty = func_ty }
- in
+ (* Compute the expression corresponding to the function.
+ We simply lookup the variable introduced for the backward function. *)
+ let func = RegionGroupId.Map.find back_id (Option.get info.back_funs) in
(* Update the context and return *)
({ ctx with calls; abstractions }, func)
@@ -1124,20 +1102,34 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
let inputs_no_state =
List.map (fun ty -> (Some "ret", ty)) inputs_no_state
in
- (* In case we merge the forward/backward functions:
- we consider the backward function as stateful and potentially failing
+ (* We consider a backward function as stateful and potentially failing
**only if it has inputs** (for the "potentially failing": if it has
not inputs, we directly evaluate it in the body of the forward function).
+
+ For instance, we do the following:
+ {[
+ // Rust
+ fn push<T, 'a>(v : &mut Vec<T>, x : T) { ... }
+
+ (* Generated code: before doing unit elimination.
+ We return (), as well as the backward function; as the backward
+ function doesn't consume any inputs, it is a value that we compute
+ directly in the body of [push].
+ *)
+ let push T (v : Vec T) (x : T) : Result (() * Vec T) = ...
+
+ (* Generated code: after doing unit elimination, if we simplify the merged
+ fwd/back functions (see below). *)
+ let push T (v : Vec T) (x : T) : Result (Vec T) = ...
+ ]}
*)
let back_effect_info =
- if !Config.return_back_funs then
- let b = inputs_no_state <> [] in
- {
- back_effect_info with
- stateful = back_effect_info.stateful && b;
- can_fail = back_effect_info.can_fail && b;
- }
- else back_effect_info
+ let b = inputs_no_state <> [] in
+ {
+ back_effect_info with
+ stateful = back_effect_info.stateful && b;
+ can_fail = back_effect_info.can_fail && b;
+ }
in
let state =
if back_effect_info.stateful then [ (None, mk_state_ty) ] else []
@@ -1145,8 +1137,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
let inputs = inputs_no_state @ state in
let output_names, outputs = compute_back_outputs_for_gid gid in
let filter =
- !Config.simplify_merged_fwd_backs
- && !Config.return_back_funs && inputs = [] && outputs = []
+ !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
in
let info =
{
@@ -1186,7 +1177,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
}
in
let ignore_output =
- if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then
+ if !Config.simplify_merged_fwd_backs then
ty_is_unit fwd_output
&& List.exists
(fun (info : back_sg_info) -> not info.filter)
@@ -1296,10 +1287,10 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig)
(subst : (generic_args * trait_instance_id) option) : ty option list =
List.map (Option.map snd) (compute_back_tys_with_info dsg subst)
-(** In case we merge the fwd/back functions: compute the output type of
- a function, from a decomposed signature. *)
+(** Compute the output type of a function, from a decomposed signature
+ (the output type contains the type of the value returned by the forward
+ function as well as the types of the returned backward functions). *)
let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
- assert !Config.return_back_funs;
(* Compute the arrow types for all the backward functions *)
let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
(* Group the forward output and the types of the backward functions *)
@@ -1315,8 +1306,8 @@ let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
in
mk_output_ty_from_effect_info effect_info output
-let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
- (gid : RegionGroupId.id option) : fun_sig =
+let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) : fun_sig
+ =
let generics = dsg.generics in
let llbc_generics = dsg.llbc_generics in
let preds = dsg.preds in
@@ -1329,27 +1320,10 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid, info.effect_info))
(RegionGroupId.Map.bindings dsg.back_sg))
in
- let mk_output_ty = mk_output_ty_from_effect_info in
let inputs, output =
- (* Two cases depending on whether we split the forward/backward functions or not *)
- if !Config.return_back_funs then (
- assert (gid = None);
- let output = compute_output_ty_from_decomposed dsg in
- let inputs = dsg.fwd_inputs in
- (inputs, output))
- else
- match gid with
- | None ->
- let effect_info = dsg.fwd_info.effect_info in
- let output = mk_output_ty effect_info dsg.fwd_output in
- (dsg.fwd_inputs, output)
- | Some gid ->
- let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
- let effect_info = back_sg.effect_info in
- let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in
- let output = mk_simpl_tuple_ty back_sg.outputs in
- let output = mk_output_ty effect_info output in
- (inputs, output)
+ let output = compute_output_ty_from_decomposed dsg in
+ let inputs = dsg.fwd_inputs in
+ (inputs, output)
in
{ generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info }
@@ -1933,16 +1907,14 @@ and translate_panic (ctx : bs_ctx) : texpression =
*)
match ctx.bid with
| None ->
- if !Config.return_back_funs then
- let back_tys = compute_back_tys ctx.sg None in
- let back_tys = List.filter_map (fun x -> x) back_tys in
- let tys =
- if ctx.sg.fwd_info.ignore_output then back_tys
- else ctx.sg.fwd_output :: back_tys
- in
- let output = mk_simpl_tuple_ty tys in
- mk_output output
- else mk_output ctx.sg.fwd_output
+ let back_tys = compute_back_tys ctx.sg None in
+ let back_tys = List.filter_map (fun x -> x) back_tys in
+ let tys =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty tys in
+ mk_output output
| Some bid ->
let output =
mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs
@@ -2063,7 +2035,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| S.Fun (fid, call_id) ->
(* Regular function call *)
let fid_t = translate_fun_id_or_trait_method_ref ctx fid in
- let func = Fun (FromLlbc (fid_t, None, None)) in
+ let func = Fun (FromLlbc (fid_t, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info = get_fun_effect_info ctx fid None None in
@@ -2080,107 +2052,103 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
- (* If we do not split the forward/backward functions: generate the
- variables for the backward functions returned by the forward
+ (* Generate the variables for the backward functions returned by the forward
function. *)
let ctx, ignore_fwd_output, back_funs_map, back_funs =
- if !Config.return_back_funs then (
- (* We need to compute the signatures of the backward functions. *)
- let sg = Option.get call.sg in
- let decls_ctx = ctx.decls_ctx in
- let dsg =
- translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
- fid call.regions_hierarchy sg
- (List.map (fun _ -> None) sg.inputs)
- in
- log#ldebug
- (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics));
- let tr_self, all_generics =
- match call.trait_method_generics with
- | None -> (UnknownTrait __FUNCTION__, generics)
- | Some (all_generics, tr_self) ->
- let all_generics =
- ctx_translate_fwd_generic_args ctx all_generics
- in
- let tr_self =
- translate_fwd_trait_instance_id ctx.type_ctx.type_infos
- tr_self
+ (* We need to compute the signatures of the backward functions. *)
+ let sg = Option.get call.sg in
+ let decls_ctx = ctx.decls_ctx in
+ let dsg =
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx fid
+ call.regions_hierarchy sg
+ (List.map (fun _ -> None) sg.inputs)
+ in
+ log#ldebug
+ (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics));
+ let tr_self, all_generics =
+ match call.trait_method_generics with
+ | None -> (UnknownTrait __FUNCTION__, generics)
+ | Some (all_generics, tr_self) ->
+ let all_generics =
+ ctx_translate_fwd_generic_args ctx all_generics
+ in
+ let tr_self =
+ translate_fwd_trait_instance_id ctx.type_ctx.type_infos
+ tr_self
+ in
+ (tr_self, all_generics)
+ in
+ let back_tys =
+ compute_back_tys_with_info dsg (Some (all_generics, tr_self))
+ in
+ (* Introduce variables for the backward functions *)
+ (* Compute a proper basename for the variables *)
+ let back_fun_name =
+ let name =
+ match fid with
+ | FunId (FAssumed fid) -> (
+ match fid with
+ | BoxNew -> "box_new"
+ | BoxFree -> "box_free"
+ | ArrayRepeat -> "array_repeat"
+ | ArrayIndexShared -> "index_shared"
+ | ArrayIndexMut -> "index_mut"
+ | ArrayToSliceShared -> "to_slice_shared"
+ | ArrayToSliceMut -> "to_slice_mut"
+ | SliceIndexShared -> "index_shared"
+ | SliceIndexMut -> "index_mut")
+ | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
+ let decl =
+ FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
in
- (tr_self, all_generics)
+ match Collections.List.last decl.name with
+ | PeIdent (s, _) -> s
+ | PeImpl _ ->
+ (* We shouldn't get there *)
+ raise (Failure "Unexpected"))
in
- let back_tys =
- compute_back_tys_with_info dsg (Some (all_generics, tr_self))
- in
- (* Introduce variables for the backward functions *)
- (* Compute a proper basename for the variables *)
- let back_fun_name =
- let name =
- match fid with
- | FunId (FAssumed fid) -> (
- match fid with
- | BoxNew -> "box_new"
- | BoxFree -> "box_free"
- | ArrayRepeat -> "array_repeat"
- | ArrayIndexShared -> "index_shared"
- | ArrayIndexMut -> "index_mut"
- | ArrayToSliceShared -> "to_slice_shared"
- | ArrayToSliceMut -> "to_slice_mut"
- | SliceIndexShared -> "index_shared"
- | SliceIndexMut -> "index_mut")
- | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
- let decl =
- FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
- in
- match Collections.List.last decl.name with
- | PeIdent (s, _) -> s
- | PeImpl _ ->
- (* We shouldn't get there *)
- raise (Failure "Unexpected"))
- in
- name ^ "_back"
- in
- let ctx, back_vars =
- fresh_opt_vars
- (List.map
- (fun ty ->
- match ty with
- | None -> None
- | Some (back_sg, ty) ->
- (* We insert a name for the variable only if the function
- can fail: if it can fail, it means the call returns a backward
- function. Otherwise, we it directly returns the value given
- back by the backward function, which means we shouldn't
- give it a name like "back..." (it doesn't make sense) *)
- let name =
- if back_sg.effect_info.can_fail then
- Some back_fun_name
- else None
- in
- Some (name, ty))
- back_tys)
- ctx
- in
- let back_funs =
- List.filter_map
- (fun v ->
- match v with
- | None -> None
- | Some v -> Some (mk_typed_pattern_from_var v None))
- back_vars
- in
- let gids =
- List.map
- (fun (g : T.region_var_group) -> g.id)
- call.regions_hierarchy
- in
- let back_vars =
- List.map (Option.map mk_texpression_from_var) back_vars
- in
- let back_funs_map =
- RegionGroupId.Map.of_list (List.combine gids back_vars)
- in
- (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs))
- else (ctx, false, None, [])
+ name ^ "_back"
+ in
+ let ctx, back_vars =
+ fresh_opt_vars
+ (List.map
+ (fun ty ->
+ match ty with
+ | None -> None
+ | Some (back_sg, ty) ->
+ (* We insert a name for the variable only if the function
+ can fail: if it can fail, it means the call returns a backward
+ function. Otherwise, we it directly returns the value given
+ back by the backward function, which means we shouldn't
+ give it a name like "back..." (it doesn't make sense) *)
+ let name =
+ if back_sg.effect_info.can_fail then Some back_fun_name
+ else None
+ in
+ Some (name, ty))
+ back_tys)
+ ctx
+ in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids =
+ List.map
+ (fun (g : T.region_var_group) -> g.id)
+ call.regions_hierarchy
+ in
+ let back_vars =
+ List.map (Option.map mk_texpression_from_var) back_vars
+ in
+ let back_funs_map =
+ RegionGroupId.Map.of_list (List.combine gids back_vars)
+ in
+ (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)
in
(* Compute the pattern for the destination *)
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
@@ -2407,19 +2375,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
raise (Failure "Unreachable")
in
let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in
- let generics = ctx_translate_fwd_generic_args ctx call.generics in
- (* Retrieve the original call and the parent abstractions *)
- let _forward, backwards = get_abs_ancestors ctx abs call_id in
- (* Retrieve the values consumed when we called the forward function and
- * ended the parent backward functions: those give us part of the input
- * values (rem: for now, as we disallow nested lifetimes, there can't be
- * parent backward functions).
- * Note that the forward inputs **include the fuel and the input state**
- * (if we use those). *)
- let fwd_inputs = call_info.forward_inputs in
- let back_ancestors_inputs =
- List.concat (List.map (fun (_abs, args) -> args) backwards)
- in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: those give us the remaining input values *)
let back_inputs = abs_to_consumed ctx ectx abs in
@@ -2434,11 +2389,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
- (* Concatenate all the inpus *)
- let inherited_inputs =
- if !Config.return_back_funs then []
- else List.concat [ fwd_inputs; back_ancestors_inputs ]
- in
let back_inputs = List.append back_inputs back_state in
(* Retrieve the values given back by this function: those are the output
* values. We rely on the fact that there are no nested borrows to use the
@@ -2459,58 +2409,33 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
(* Retrieve the function id, and register the function call in the context
if necessary.Arith_status *)
let ctx, func =
- bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs
- back_inputs generics output.ty ctx
+ bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
- let inputs = List.append inherited_inputs back_inputs in
+ let inputs = back_inputs in
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
List.map
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- (* **Optimization**:
- =================
- We do a small optimization here if we split the forward/backward functions.
- If the backward function doesn't have any output, we don't introduce any function
- call.
- See the comment in {!Config.filter_useless_monadic_calls}.
-
- TODO: use an option to disallow backward functions from updating the state.
- TODO: a backward function which only gives back shared borrows shouldn't
- update the state (state updates should only be used for mutable borrows,
- with objects like Rc for instance).
- *)
- if
- (not !Config.return_back_funs)
- && !Config.filter_useless_monadic_calls
- && outputs = [] && nstate = None
- then (
- (* No outputs - we do a small sanity check: the backward function
- should have exactly the same number of inputs as the forward:
- this number can be different only if the forward function returned
- a value containing mutable borrows, which can't be the case... *)
- assert (List.length inputs = List.length fwd_inputs);
- next_e)
- else
- (* The backward function might also have been filtered if we do not
- split the forward/backward functions *)
- match func with
- | None -> next_e
- | Some func ->
- log#ldebug
- (lazy
- (let args = List.map (texpression_to_string ctx) args in
- "func: "
- ^ texpression_to_string ctx func
- ^ "\nfunc type: "
- ^ pure_ty_to_string ctx func.ty
- ^ "\n\nargs:\n" ^ String.concat "\n" args));
- let call = mk_apps func args in
- mk_let effect_info.can_fail output call next_e
+ (* The backward function might have been filtered it does nothing
+ (consumes unit and returns unit). *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ log#ldebug
+ (lazy
+ (let args = List.map (texpression_to_string ctx) args in
+ "func: "
+ ^ texpression_to_string ctx func
+ ^ "\nfunc type: "
+ ^ pure_ty_to_string ctx func.ty
+ ^ "\n\nargs:\n" ^ String.concat "\n" args));
+ let call = mk_apps func args in
+ mk_let effect_info.can_fail output call next_e
and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -2614,8 +2539,6 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
- let generics = loop_info.generics in
- let fwd_inputs = Option.get loop_info.forward_inputs in
(* Retrieve the additional backward inputs. Note that those are actually
the backward inputs of the function we are synthesizing (and that we
need to *transmit* to the loop backward function): they are not the
@@ -2637,10 +2560,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
else ([], ctx, None)
in
(* Concatenate all the inputs *)
- let inputs =
- if !Config.return_back_funs then List.concat [ back_inputs; back_state ]
- else List.concat [ fwd_inputs; back_inputs; back_state ]
- in
+ let inputs = List.concat [ back_inputs; back_state ] in
(* Retrieve the values given back by this function *)
let ctx, outputs = abs_to_given_back None abs ctx in
(* Group the output values together: first the updated inputs *)
@@ -2660,87 +2580,52 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output.ty else output.ty
- in
(* Create the expression for the function:
- it is either a call to a top-level function, if we split the
forward/backward functions
- or a call to the variable we introduced for the backward function,
if we merge the forward/backward functions *)
let func =
- if !Config.return_back_funs then
- RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
- else
- let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
- let func = { id = FunOrOp func; generics } in
- Some { e = Qualif func; ty = func_ty }
+ RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
in
- (* **Optimization**:
- =================
- We do a small optimization here in case we split the forward/backward
- functions.
- If the backward function doesn't have any output, we don't introduce
- any function call.
- See the comment in {!Config.filter_useless_monadic_calls}.
-
- TODO: use an option to disallow backward functions from updating the state.
- TODO: a backward function which only gives back shared borrows shouldn't
- update the state (state updates should only be used for mutable borrows,
- with objects like Rc for instance).
- *)
- if
- (not !Config.return_back_funs)
- && !Config.filter_useless_monadic_calls
- && outputs = [] && nstate = None
- then (
- (* No outputs - we do a small sanity check: the backward function
- should have exactly the same number of inputs as the forward:
- this number can be different only if the forward function returned
- a value containing mutable borrows, which can't be the case... *)
- assert (List.length inputs = List.length fwd_inputs);
- next_e)
- else
- (* In case we merge the fwd/back functions we filter the backward
- functions elsewhere *)
- match func with
- | None -> next_e
- | Some func ->
- let call = mk_apps func args in
- (* Add meta-information - this is slightly hacky: we look at the
- values consumed by the abstraction (note that those come from
- *before* we applied the fixed-point context) and use them to
- guide the naming of the output vars.
-
- Also, we need to convert the backward outputs from patterns to
- variables.
-
- Finally, in practice, this works well only for loop bodies:
- we do this only in this case.
- TODO: improve the heuristics, to give weight to the hints for
- instance.
- *)
- let next_e =
- if ctx.inside_loop then
- let consumed_values = abs_to_consumed ctx ectx abs in
- let var_values = List.combine outputs consumed_values in
- let var_values =
- List.filter_map
- (fun (var, v) ->
- match var.Pure.value with
- | PatVar (var, _) -> Some (var, v)
- | _ -> None)
- var_values
- in
- let vars, values = List.split var_values in
- mk_emeta_symbolic_assignments vars values next_e
- else next_e
- in
+ (* We may have filtered the backward function elsewhere if it doesn't
+ do anything (doesn't consume anything and doesn't return anything) *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ let call = mk_apps func args in
+ (* Add meta-information - this is slightly hacky: we look at the
+ values consumed by the abstraction (note that those come from
+ *before* we applied the fixed-point context) and use them to
+ guide the naming of the output vars.
+
+ Also, we need to convert the backward outputs from patterns to
+ variables.
+
+ Finally, in practice, this works well only for loop bodies:
+ we do this only in this case.
+ TODO: improve the heuristics, to give weight to the hints for
+ instance.
+ *)
+ let next_e =
+ if ctx.inside_loop then
+ let consumed_values = abs_to_consumed ctx ectx abs in
+ let var_values = List.combine outputs consumed_values in
+ let var_values =
+ List.filter_map
+ (fun (var, v) ->
+ match var.Pure.value with
+ | PatVar (var, _) -> Some (var, v)
+ | _ -> None)
+ var_values
+ in
+ let vars, values = List.split var_values in
+ mk_emeta_symbolic_assignments vars values next_e
+ else next_e
+ in
- (* Create the let-binding *)
- mk_let effect_info.can_fail output call next_e)
+ (* Create the let-binding *)
+ mk_let effect_info.can_fail output call next_e)
and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -3068,48 +2953,40 @@ and translate_forward_end (ectx : C.eval_ctx)
*)
let ctx =
(* Introduce variables for the inputs and the state variable
- and update the context. *)
- if !Config.return_back_funs then
- (* If the forward/backward functions are not split, we need
- to introduce fresh variables for the additional inputs,
- because they are locally introduced in a lambda *)
- let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
- let ctx, backward_inputs_no_state =
- fresh_vars back_sg.inputs_no_state ctx
- in
- let ctx, backward_inputs_with_state =
- if back_sg.effect_info.stateful then
- let ctx, var, _ = bs_ctx_fresh_state_var ctx in
- (ctx, backward_inputs_no_state @ [ var ])
- else (ctx, backward_inputs_no_state)
- in
- {
- ctx with
- backward_inputs_no_state =
- RegionGroupId.Map.add bid backward_inputs_no_state
- ctx.backward_inputs_no_state;
- backward_inputs_with_state =
- RegionGroupId.Map.add bid backward_inputs_with_state
- ctx.backward_inputs_with_state;
- }
- else
- (* Update the state variable *)
- let back_state_var =
- RegionGroupId.Map.find bid ctx.back_state_vars
- in
- { ctx with state_var = back_state_var }
+ and update the context.
+
+ We need to introduce fresh variables for the additional inputs,
+ because they are locally introduced in a lambda.
+ *)
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ let ctx, backward_inputs_no_state =
+ fresh_vars back_sg.inputs_no_state ctx
+ in
+ let ctx, backward_inputs_with_state =
+ if back_sg.effect_info.stateful then
+ let ctx, var, _ = bs_ctx_fresh_state_var ctx in
+ (ctx, backward_inputs_no_state @ [ var ])
+ else (ctx, backward_inputs_no_state)
+ in
+ {
+ ctx with
+ backward_inputs_no_state =
+ RegionGroupId.Map.add bid backward_inputs_no_state
+ ctx.backward_inputs_no_state;
+ backward_inputs_with_state =
+ RegionGroupId.Map.add bid backward_inputs_with_state
+ ctx.backward_inputs_with_state;
+ }
in
let e = T.RegionGroupId.Map.find bid back_e in
let finish e =
(* Wrap in lambdas if necessary *)
- if !Config.return_back_funs then
- let inputs =
- RegionGroupId.Map.find bid ctx.backward_inputs_with_state
- in
- let places = List.map (fun _ -> None) inputs in
- mk_lambdas_from_vars inputs places e
- else e
+ let inputs =
+ RegionGroupId.Map.find bid ctx.backward_inputs_with_state
+ in
+ let places = List.map (fun _ -> None) inputs in
+ mk_lambdas_from_vars inputs places e
in
(ctx, e, finish)
in
@@ -3131,85 +3008,83 @@ and translate_forward_end (ectx : C.eval_ctx)
function, if needs be, and lookup the proper expression.
*)
let translate_end ctx =
- if !Config.return_back_funs then
- (* Compute the output of the forward function *)
- let fwd_effect_info = ctx.sg.fwd_info.effect_info in
- let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
- let fwd_e = translate_one_end ctx None in
-
- (* Introduce the backward functions. *)
- let back_el =
- List.map
- (fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
- translate_one_end ctx (Some gid))
- (RegionGroupId.Map.bindings ctx.sg.back_sg)
- in
+ (* Compute the output of the forward function *)
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
+ let fwd_e = translate_one_end ctx None in
- (* Compute whether the backward expressions should be evaluated straight
- away or not (i.e., if we should bind them with monadic let-bindings
- or not). We evaluate them straight away if they can fail and have no
- inputs. *)
- let evaluate_backs =
- List.map
- (fun (sg : back_sg_info) ->
- if !Config.simplify_merged_fwd_backs then
- sg.inputs = [] && sg.effect_info.can_fail
- else false)
- (RegionGroupId.Map.values ctx.sg.back_sg)
- in
+ (* Introduce the backward functions. *)
+ let back_el =
+ List.map
+ (fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
+ translate_one_end ctx (Some gid))
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
- (* Introduce variables for the backward functions.
- We lookup the LLBC definition in an attempt to derive pretty names
- for those functions. *)
- let _, back_vars = fresh_back_vars_for_current_fun ctx in
+ (* Compute whether the backward expressions should be evaluated straight
+ away or not (i.e., if we should bind them with monadic let-bindings
+ or not). We evaluate them straight away if they can fail and have no
+ inputs. *)
+ let evaluate_backs =
+ List.map
+ (fun (sg : back_sg_info) ->
+ if !Config.simplify_merged_fwd_backs then
+ sg.inputs = [] && sg.effect_info.can_fail
+ else false)
+ (RegionGroupId.Map.values ctx.sg.back_sg)
+ in
- (* Create the return expressions *)
- let vars =
- let back_vars = List.filter_map (fun x -> x) back_vars in
- if ctx.sg.fwd_info.ignore_output then back_vars
- else pure_fwd_var :: back_vars
- in
- let vars = List.map mk_texpression_from_var vars in
- let ret = mk_simpl_tuple_texpression vars in
-
- (* Introduce a fresh input state variable for the forward expression *)
- let _ctx, state_var, state_pat =
- if fwd_effect_info.stateful then
- let ctx, var, pat = bs_ctx_fresh_state_var ctx in
- (ctx, [ var ], [ pat ])
- else (ctx, [], [])
- in
+ (* Introduce variables for the backward functions.
+ We lookup the LLBC definition in an attempt to derive pretty names
+ for those functions. *)
+ let _, back_vars = fresh_back_vars_for_current_fun ctx in
- let state_var = List.map mk_texpression_from_var state_var in
- let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
- let ret = mk_result_return_texpression ret in
+ (* Create the return expressions *)
+ let vars =
+ let back_vars = List.filter_map (fun x -> x) back_vars in
+ if ctx.sg.fwd_info.ignore_output then back_vars
+ else pure_fwd_var :: back_vars
+ in
+ let vars = List.map mk_texpression_from_var vars in
+ let ret = mk_simpl_tuple_texpression vars in
+
+ (* Introduce a fresh input state variable for the forward expression *)
+ let _ctx, state_var, state_pat =
+ if fwd_effect_info.stateful then
+ let ctx, var, pat = bs_ctx_fresh_state_var ctx in
+ (ctx, [ var ], [ pat ])
+ else (ctx, [], [])
+ in
- (* Introduce all the let-bindings *)
+ let state_var = List.map mk_texpression_from_var state_var in
+ let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
+ let ret = mk_result_return_texpression ret in
- (* Combine:
- - the backward variables
- - whether we should evaluate the expression for the backward function
- (i.e., should we use a monadic let-binding or not - we do if the
- backward functions don't have inputs and can fail)
- - the expressions for the backward functions
- *)
- let back_vars_els =
- List.filter_map
- (fun (v, (eval, el)) ->
- match v with None -> None | Some v -> Some (v, eval, el))
- (List.combine back_vars (List.combine evaluate_backs back_el))
- in
- let e =
- List.fold_right
- (fun (var, evaluate, back_e) e ->
- mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
- back_vars_els ret
- in
- (* Bind the expression for the forward output *)
- let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
- let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
- mk_let fwd_effect_info.can_fail pat fwd_e e
- else translate_one_end ctx ctx.bid
+ (* Introduce all the let-bindings *)
+
+ (* Combine:
+ - the backward variables
+ - whether we should evaluate the expression for the backward function
+ (i.e., should we use a monadic let-binding or not - we do if the
+ backward functions don't have inputs and can fail)
+ - the expressions for the backward functions
+ *)
+ let back_vars_els =
+ List.filter_map
+ (fun (v, (eval, el)) ->
+ match v with None -> None | Some v -> Some (v, eval, el))
+ (List.combine back_vars (List.combine evaluate_backs back_el))
+ in
+ let e =
+ List.fold_right
+ (fun (var, evaluate, back_e) e ->
+ mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
+ back_vars_els ret
+ in
+ (* Bind the expression for the forward output *)
+ let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
+ let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
+ mk_let fwd_effect_info.can_fail pat fwd_e e
in
(* If we are (re-)entering a loop, we need to introduce a call to the
@@ -3279,24 +3154,22 @@ and translate_forward_end (ectx : C.eval_ctx)
backward functions of the outer function.
*)
let ctx, back_funs_map, back_funs =
- if !Config.return_back_funs then
- let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
- let back_funs =
- List.filter_map
- (fun v ->
- match v with
- | None -> None
- | Some v -> Some (mk_typed_pattern_from_var v None))
- back_vars
- in
- let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
- let back_funs_map =
- RegionGroupId.Map.of_list
- (List.combine gids
- (List.map (Option.map mk_texpression_from_var) back_vars))
- in
- (ctx, Some back_funs_map, back_funs)
- else (ctx, None, [])
+ let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
+ let back_funs_map =
+ RegionGroupId.Map.of_list
+ (List.combine gids
+ (List.map (Option.map mk_texpression_from_var) back_vars))
+ in
+ (ctx, Some back_funs_map, back_funs)
in
(* Introduce patterns *)
@@ -3339,7 +3212,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let out_pat = mk_simpl_tuple_pattern out_pats in
let loop_call =
- let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in
+ let fun_id = Fun (FromLlbc (FunId fid, Some loop_id)) in
let func = { id = FunOrOp fun_id; generics = loop_info.generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
@@ -3438,91 +3311,58 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* The output type of the loop function *)
let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in
let back_effect_infos, output_ty =
- if !Config.return_back_funs then
- (* The loop backward functions consume the same additional inputs as the parent
- function, but have custom outputs *)
- let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
- let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
- let back_info_tys =
- List.map
- (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
- (* Remark: the effect info of the backward function for the loop
- is almost the same as for the backward function of the parent function.
- Quite importantly, the fact that the function is stateful and/or can fail
- mostly depends on whether it has inputs or not, and the backward functions
- for the loops have the same inputs as the backward functions for the parent
- function.
- *)
- let effect_info = back_sg.effect_info in
- let effect_info = { effect_info with is_rec = true } in
- (* Compute the input/output types *)
- let inputs = List.map snd back_sg.inputs in
- let outputs = given_back in
- (* Filter if necessary *)
- let ty =
- if
- !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
- then None
- else
- let output = mk_simpl_tuple_ty outputs in
- let output =
- mk_back_output_ty_from_effect_info effect_info inputs output
- in
- let ty = mk_arrows inputs output in
- Some ty
- in
- ((id, effect_info), ty))
- (List.combine back_sgs given_back_tys)
- in
- let back_info = List.map fst back_info_tys in
- let back_info = RegionGroupId.Map.of_list back_info in
- let back_tys = List.filter_map snd back_info_tys in
- let output =
- if ctx.sg.fwd_info.ignore_output then back_tys
- else ctx.sg.fwd_output :: back_tys
- in
- let output = mk_simpl_tuple_ty output in
- let effect_info = ctx.sg.fwd_info.effect_info in
- let output =
- if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- let output =
- if effect_info.can_fail && inputs <> [] then mk_result_ty output
- else output
- in
- (back_info, output)
- else
- let back_info =
- RegionGroupId.Map.of_list
- (List.map
- (fun ((id, back_sg) : _ * back_sg_info) ->
- (id, { back_sg.effect_info with is_rec = true }))
- (RegionGroupId.Map.bindings ctx.sg.back_sg))
- in
- let output =
- match ctx.bid with
- | None ->
- (* Forward function: same type as the parent function *)
- (translate_fun_sig_from_decomposed ctx.sg None).output
- | Some rg_id ->
- (* Backward function: custom return type *)
- let doutputs =
- T.RegionGroupId.Map.find rg_id rg_to_given_back_tys
- in
- let output = mk_simpl_tuple_ty doutputs in
- let fwd_effect_info = ctx.sg.fwd_info.effect_info in
- let output =
- if fwd_effect_info.stateful then
- mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- let output =
- if fwd_effect_info.can_fail then mk_result_ty output else output
- in
- output
- in
- (back_info, output)
+ (* The loop backward functions consume the same additional inputs as the parent
+ function, but have custom outputs *)
+ let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
+ let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
+ let back_info_tys =
+ List.map
+ (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
+ (* Remark: the effect info of the backward function for the loop
+ is almost the same as for the backward function of the parent function.
+ Quite importantly, the fact that the function is stateful and/or can fail
+ mostly depends on whether it has inputs or not, and the backward functions
+ for the loops have the same inputs as the backward functions for the parent
+ function.
+ *)
+ let effect_info = back_sg.effect_info in
+ let effect_info = { effect_info with is_rec = true } in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = given_back in
+ (* Filter if necessary *)
+ let ty =
+ if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+ then None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ Some ty
+ in
+ ((id, effect_info), ty))
+ (List.combine back_sgs given_back_tys)
+ in
+ let back_info = List.map fst back_info_tys in
+ let back_info = RegionGroupId.Map.of_list back_info in
+ let back_tys = List.filter_map snd back_info_tys in
+ let output =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty output in
+ let effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output
+ else output
+ in
+ (back_info, output)
in
(* Add the loop information in the context *)
@@ -3708,31 +3548,26 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(* Translate *)
let def = ctx.fun_decl in
- let bid = ctx.bid in
+ assert (ctx.bid = None);
log#ldebug
(lazy
("SymbolicToPure.translate_fun_decl: "
^ name_to_string ctx def.name
- ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string bid
- ^ ")\n"));
+ ^ "\n"));
(* Translate the declaration *)
let def_id = def.def_id in
let llbc_name = def.name in
let name = name_to_string ctx llbc_name in
(* Translate the signature *)
- let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in
- let regions_hierarchy =
- FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies
- in
+ let signature = translate_fun_sig_from_decomposed ctx.sg in
(* Translate the body, if there is *)
let body =
match body with
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx (FunId (FRegular def_id)) None bid
+ get_fun_effect_info ctx (FunId (FRegular def_id)) None None
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
@@ -3760,37 +3595,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
if effect_info.stateful_group then [ mk_state_var ctx.state_var ]
else []
in
- (* Compute the list of (properly ordered) backward input variables *)
- let backward_inputs : var list =
- match bid with
- | None -> []
- | Some back_id ->
- assert (not !Config.return_back_funs);
- let parents_ids =
- list_ordered_ancestor_region_groups regions_hierarchy back_id
- in
- let backward_ids = List.append parents_ids [ back_id ] in
- List.concat
- (List.map
- (fun id ->
- T.RegionGroupId.Map.find id ctx.backward_inputs_no_state)
- backward_ids)
- in
- (* Introduce the backward input state (the state at call site of the
- * *backward* function), if necessary *)
- let back_state =
- if effect_info.stateful && Option.is_some bid then
- let state_var =
- RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars
- in
- [ mk_state_var state_var ]
- else []
- in
(* Group the inputs together *)
- let inputs =
- List.concat
- [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ]
- in
+ let inputs = List.concat [ fuel; ctx.forward_inputs; fwd_state ] in
let inputs_lvs =
List.map (fun v -> mk_typed_pattern_from_var v None) inputs
in
@@ -3799,16 +3605,10 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(lazy
("SymbolicToPure.translate_fun_decl: "
^ name_to_string ctx def.name
- ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string bid
- ^ ")" ^ "\n- forward_inputs: "
+ ^ "\n- inputs: "
^ String.concat ", " (List.map show_var ctx.forward_inputs)
- ^ "\n- fwd_state: "
+ ^ "\n- state: "
^ String.concat ", " (List.map show_var fwd_state)
- ^ "\n- backward_inputs: "
- ^ String.concat ", " (List.map show_var backward_inputs)
- ^ "\n- back_state: "
- ^ String.concat ", " (List.map show_var back_state)
^ "\n- signature.inputs: "
^ String.concat ", "
(List.map (pure_ty_to_string ctx) signature.inputs)));
@@ -3837,7 +3637,6 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
kind = def.kind;
num_loops;
loop_id;
- back_id = bid;
llbc_name;
name;
signature;
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 55a94302..c12de045 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -1,5 +1,4 @@
open Interpreter
-open Expressions
open Types
open Values
open LlbcAst
@@ -49,8 +48,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
log#ldebug
(lazy ("translate_function_to_pure: " ^ name_to_string trans_ctx fdef.name));
- let def_id = fdef.def_id in
-
(* Compute the symbolic ASTs, if the function is transparent *)
let symbolic_trans = translate_function_to_symbolics trans_ctx fdef in
@@ -124,20 +121,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef
in
- let regions_hierarchy =
- LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_ctx.regions_hierarchies
- in
-
- let var_counter, back_state_vars =
- if !Config.return_back_funs then (var_counter, [])
- else
- List.fold_left_map
- (fun var_counter (region_vars : region_var_group) ->
- let gid = region_vars.id in
- let var, var_counter = Pure.VarId.fresh var_counter in
- (var_counter, (gid, var)))
- var_counter regions_hierarchy
- in
+ let var_counter, back_state_vars = (var_counter, []) in
let back_state_vars = RegionGroupId.Map.of_list back_state_vars in
let ctx =
@@ -195,28 +179,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
in
(* Add the backward inputs *)
- let ctx, backward_inputs_no_state, backward_inputs_with_state =
- if !Config.return_back_funs then (ctx, [], [])
- else
- let ctx, inputs_no_with_state =
- List.fold_left_map
- (fun ctx (region_vars : region_var_group) ->
- let gid = region_vars.id in
- let back_sg = RegionGroupId.Map.find gid sg.back_sg in
- let ctx, no_state =
- SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx
- in
- let ctx, with_state =
- SymbolicToPure.fresh_vars back_sg.inputs ctx
- in
- (ctx, ((gid, no_state), (gid, with_state))))
- ctx regions_hierarchy
- in
- let inputs_no_state, inputs_with_state =
- List.split inputs_no_with_state
- in
- (ctx, inputs_no_state, inputs_with_state)
- in
+ let backward_inputs_no_state, backward_inputs_with_state = ([], []) in
let backward_inputs_no_state =
RegionGroupId.Map.of_list backward_inputs_no_state
in
@@ -225,40 +188,10 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
in
let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in
- (* Translate the forward function *)
- let pure_forward =
- match symbolic_trans with
- | None -> SymbolicToPure.translate_fun_decl ctx None
- | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast)
- in
-
- (* Translate the backward functions, if we split the forward and backward functions *)
- let translate_backward (rg : region_var_group) : Pure.fun_decl =
- (* For the backward inputs/outputs initialization: we use the fact that
- * there are no nested borrows for now, and so that the region groups
- * can't have parents *)
- assert (rg.parents = []);
- let back_id = rg.id in
-
- match symbolic_trans with
- | None ->
- (* Initialize the context *)
- let ctx = { ctx with bid = Some back_id } in
- (* Translate *)
- SymbolicToPure.translate_fun_decl ctx None
- | Some (_, symbolic) ->
- (* Initialize the context *)
- let ctx = { ctx with bid = Some back_id } in
- (* Translate *)
- SymbolicToPure.translate_fun_decl ctx (Some symbolic)
- in
- let pure_backwards =
- if !Config.return_back_funs then []
- else List.map translate_backward regions_hierarchy
- in
-
- (* Return *)
- (pure_forward, pure_backwards)
+ (* Translate the function *)
+ match symbolic_trans with
+ | None -> SymbolicToPure.translate_fun_decl ctx None
+ | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast)
(* TODO: factor out the return type *)
let translate_crate_to_pure (crate : crate) :
@@ -513,9 +446,8 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
let global_decls = ctx.trans_ctx.global_ctx.global_decls in
let global = GlobalDeclId.Map.find id global_decls in
let trans = FunDeclId.Map.find global.body ctx.trans_funs in
- assert (trans.fwd.loops = []);
- assert (trans.backs = []);
- let body = trans.fwd.f in
+ assert (trans.loops = []);
+ let body = trans.f in
let is_opaque = Option.is_none body.Pure.body in
(* Check if we extract the global *)
@@ -643,7 +575,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
let funs_map = builtin_funs_map () in
List.map
(fun (trans : pure_fun_translation) ->
- match_name_find_opt ctx.trans_ctx trans.fwd.f.llbc_name funs_map <> None)
+ match_name_find_opt ctx.trans_ctx trans.f.llbc_name funs_map <> None)
pure_ls
in
@@ -660,7 +592,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Extract the decrease clauses template bodies *)
if config.extract_template_decreases_clauses then
List.iter
- (fun { fwd; _ } ->
+ (fun f ->
(* We only generate decreases clauses for the forward functions, because
the termination argument should only depend on the forward inputs.
The backward functions thus use the same decreases clauses as the
@@ -687,27 +619,14 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
raise
(Failure "HOL4 doesn't have decreases/termination clauses")
in
- extract_decrease fwd.f;
- List.iter extract_decrease fwd.loops)
+ extract_decrease f.f;
+ List.iter extract_decrease f.loops)
pure_ls;
- (* Concatenate the function definitions, filtering the useless forward
- * functions. *)
+ (* Flatten the translated functions (concatenate the functions with
+ the declarations introduced for the loops) *)
let decls =
- List.concat
- (List.map
- (fun { keep_fwd; fwd; backs } ->
- let fwd =
- if keep_fwd then List.append fwd.loops [ fwd.f ] else []
- in
- let backs : Pure.fun_decl list =
- List.concat
- (List.map
- (fun back -> List.append back.loops [ back.f ])
- backs)
- in
- List.append fwd backs)
- pure_ls)
+ List.concat (List.map (fun f -> List.append f.loops [ f.f ]) pure_ls)
in
(* Extract the function definitions *)
@@ -724,9 +643,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Insert unit tests if necessary *)
if config.test_trans_unit_functions then
List.iter
- (fun trans ->
- if trans.keep_fwd then
- Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f)
+ (fun trans -> Extract.extract_unit_test_if_unit_fun ctx fmt trans.f)
pure_ls
(** Export a trait declaration. *)
@@ -812,7 +729,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
extract their type directly in the records we generate for
the trait declarations themselves, there is no point in having
separate type definitions) *)
- match pure_fun.fwd.f.Pure.kind with
+ match pure_fun.f.Pure.kind with
| TraitMethodDecl _ -> ()
| _ ->
(* Translate *)
@@ -1001,18 +918,18 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
* whether we should generate a decrease clause or not. *)
let rec_functions =
List.map
- (fun { fwd; _ } ->
- let fwd_f =
- if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then
- [ (fwd.f.def_id, None) ]
+ (fun trans ->
+ let f =
+ if trans.f.Pure.signature.fwd_info.effect_info.is_rec then
+ [ (trans.f.def_id, None) ]
else []
in
- let loop_fwds =
+ let loops =
List.map
(fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ])
- fwd.loops
+ trans.loops
in
- fwd_f :: loop_fwds)
+ f :: loops)
trans_funs
in
let rec_functions : PureUtils.fun_loop_id list =
@@ -1028,7 +945,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
let trans_funs : pure_fun_translation FunDeclId.Map.t =
FunDeclId.Map.of_list
(List.map
- (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans))
+ (fun (trans : pure_fun_translation) -> (trans.f.def_id, trans))
trans_funs)
in
@@ -1052,7 +969,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
names_maps;
indent_incr = 2;
use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses;
- fun_name_info = PureUtils.RegularFunIdMap.empty;
trait_decl_id = None (* None by default *);
is_provided_method = false (* false by default *);
trans_trait_decls;
@@ -1082,7 +998,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
(fun ctx (trans : pure_fun_translation) ->
(* If requested by the user, register termination measures and decreases
proofs for all the recursive functions *)
- let fwd_def = trans.fwd.f in
let gen_decr_clause (def : Pure.fun_decl) =
!Config.extract_decreases_clauses
&& PureUtils.FunLoopIdSet.mem
@@ -1091,7 +1006,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
in
(* Register the names, only if the function is not a global body -
* those are handled later *)
- let is_global = fwd_def.Pure.is_global_decl_body in
+ let is_global = trans.f.Pure.is_global_decl_body in
if is_global then ctx
else Extract.extract_fun_decl_register_names ctx gen_decr_clause trans)
ctx
@@ -1171,13 +1086,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
let exe_dir = Filename.dirname Sys.argv.(0) in
let primitives_src_dest =
match !Config.backend with
- | FStar ->
- let src =
- if !Config.return_back_funs then
- "/backends/fstar/merge/Primitives.fst"
- else "/backends/fstar/split/Primitives.fst"
- in
- Some (src, "Primitives.fst")
+ | FStar -> Some ("/backends/fstar/merge/Primitives.fst", "Primitives.fst")
| Coq -> Some ("/backends/coq/Primitives.v", "Primitives.v")
| Lean -> None
| HOL4 -> None
diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml
index 88438872..05877b5a 100644
--- a/compiler/TranslateCore.ml
+++ b/compiler/TranslateCore.ml
@@ -8,19 +8,8 @@ let log = Logging.translate_log
type trans_ctx = decls_ctx [@@deriving show]
type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list }
-type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list
-
-type pure_fun_translation = {
- keep_fwd : bool;
- (** Should we extract the forward function?
-
- If the forward function returns `()` and there is exactly one
- backward function, we may merge the forward into the backward
- function and thus don't extract the forward function)?
- *)
- fwd : fun_and_loops;
- backs : fun_and_loops list;
-}
+type pure_fun_translation_no_loops = Pure.fun_decl
+type pure_fun_translation = fun_and_loops
let trans_ctx_to_fmt_env (ctx : trans_ctx) : Print.fmt_env =
Print.Contexts.decls_ctx_to_fmt_env ctx