summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2024-03-08 16:06:35 +0100
committerSon Ho2024-03-08 16:06:35 +0100
commitfe2a2cb34148e46e32cdcfbf100e38d9986082cd (patch)
treea378134d495985718b842a786bad573695974b95
parentbc154dda94c44b3ae67a3b04d3866cc473aead32 (diff)
Make progress on propagating the changes
-rw-r--r--compiler/Config.ml44
-rw-r--r--compiler/Contexts.ml2
-rw-r--r--compiler/Extract.ml444
-rw-r--r--compiler/ExtractBase.ml161
-rw-r--r--compiler/ExtractBuiltin.ml11
-rw-r--r--compiler/ExtractTypes.ml2
-rw-r--r--compiler/Main.ml6
-rw-r--r--compiler/PrintPure.ml18
-rw-r--r--compiler/Pure.ml4
-rw-r--r--compiler/PureMicroPasses.ml331
-rw-r--r--compiler/ReorderDecls.ml12
-rw-r--r--compiler/SymbolicToPure.ml13
-rw-r--r--compiler/Translate.ml145
-rw-r--r--compiler/TranslateCore.ml15
14 files changed, 259 insertions, 949 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index 6fd866e8..af0e62d1 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -263,50 +263,6 @@ let decompose_nested_let_patterns = ref false
*)
let unfold_monadic_let_bindings = ref false
-(** Controls whether we try to filter the calls to monadic functions
- (which can fail) when their outputs are not used.
-
- The useless calls are calls to backward functions which have no outputs.
- This case happens if the original Rust function only takes *shared* borrows
- as inputs, and is thus pretty common.
-
- We are allowed to do this only because in this specific case,
- the backward function fails *exactly* when the forward function fails
- (they actually do exactly the same thing, the only difference being
- that the forward function can potentially return a value), and upon
- reaching the place where we should introduce a call to the backward
- function, we know we have introduced a call to the forward function.
-
- Also note that in general, backward functions "do more things" than
- forward functions, and have more opportunities to fail (even though
- in the generated code, calls to the backward functions should fail
- exactly when the corresponding, previous call to the forward functions
- failed).
-
- This optimization is done in {!SymbolicToPure}. We might want to move it to
- the micro-passes subsequent to the translation from symbolic to pure, but it
- is really super easy to do it when going from symbolic to pure. Note that
- we later filter the useless *forward* calls in the micro-passes, where it is
- more natural to do.
-
- See the comments for {!PureMicroPasses.expression_contains_child_call_in_all_paths}
- for additional explanations.
- *)
-let filter_useless_monadic_calls = ref true
-
-(** If {!filter_useless_monadic_calls} is activated, some functions
- become useless: if this option is true, we don't extract them.
-
- The calls to functions which always get filtered are:
- - the forward functions with unit return value
- - the backward functions which don't output anything (backward
- functions coming from rust functions with no mutable borrows
- as input values - note that if a function doesn't take mutable
- borrows as inputs, it can't return mutable borrows; we actually
- dynamically check for that).
- *)
-let filter_useless_functions = ref true
-
(** Simplify the forward/backward functions, in case we merge them
(i.e., the forward functions return the backward functions).
diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml
index b1dd9553..54411fd5 100644
--- a/compiler/Contexts.ml
+++ b/compiler/Contexts.ml
@@ -109,6 +109,8 @@ let reset_global_counters () =
region_id_counter := RegionId.generator_zero;
abstraction_id_counter := AbstractionId.generator_zero;
loop_id_counter := LoopId.generator_zero;
+ (* We want the loop id to start at 1 *)
+ let _ = fresh_loop_id () in
fun_call_id_counter := FunCallId.generator_zero;
dummy_var_id_counter := DummyVarId.generator_zero
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index dbca4f8f..794a1bfa 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -9,8 +9,7 @@ open TranslateCore
open Config
include ExtractTypes
-(** Compute the names for all the pure functions generated from a rust function
- (forward function and backward functions).
+(** Compute the names for all the pure functions generated from a rust function.
*)
let extract_fun_decl_register_names (ctx : extraction_ctx)
(has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) :
@@ -19,63 +18,36 @@ let extract_fun_decl_register_names (ctx : extraction_ctx)
method implementations): we do not need to refer to them directly. We will
only use their type for the fields of the records we generate for the trait
declarations *)
- match def.fwd.f.kind with
+ match def.f.kind with
| TraitMethodDecl _ -> ctx
| _ -> (
(* Check if the function is builtin *)
let builtin =
let open ExtractBuiltin in
let funs_map = builtin_funs_map () in
- match_name_find_opt ctx.trans_ctx def.fwd.f.llbc_name funs_map
+ match_name_find_opt ctx.trans_ctx def.f.llbc_name funs_map
in
(* Use the builtin names if necessary *)
match builtin with
- | Some (filter_info, info) ->
- (* Register the filtering information, if there is *)
+ | Some (filter_info, fun_info) ->
+ (* Builtin function: register the filtering information, if there is *)
let ctx =
match filter_info with
| Some keep ->
{
ctx with
funs_filter_type_args_map =
- FunDeclId.Map.add def.fwd.f.def_id keep
+ FunDeclId.Map.add def.f.def_id keep
ctx.funs_filter_type_args_map;
}
| _ -> ctx
in
- let funs =
- if !Config.return_back_funs then [ def.fwd.f ]
- else
- let backs = List.map (fun f -> f.f) def.backs in
- if def.keep_fwd then def.fwd.f :: backs else backs
- in
- List.fold_left
- (fun ctx (f : fun_decl) ->
- let open ExtractBuiltin in
- let fun_id =
- (Pure.FunId (FRegular f.def_id), f.loop_id, f.back_id)
- in
- let fun_info =
- List.find_opt
- (fun (x : builtin_fun_info) -> x.rg = f.back_id)
- info
- in
- match fun_info with
- | Some fun_info ->
- ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx
- | None ->
- raise
- (Failure
- ("Not found: "
- ^ name_to_string ctx f.llbc_name
- ^ ", "
- ^ Print.option_to_string Pure.show_loop_id f.loop_id
- ^ Print.option_to_string Pure.show_region_group_id
- f.back_id)))
- ctx funs
+ let f = def.f in
+ let open ExtractBuiltin in
+ let fun_id = (Pure.FunId (FRegular f.def_id), f.loop_id) in
+ ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx
| None ->
- let fwd = def.fwd in
- let backs = def.backs in
+ (* Not builtin *)
(* Register the decrease clauses, if necessary *)
let register_decreases ctx def =
if has_decreases_clause def then
@@ -88,21 +60,15 @@ let extract_fun_decl_register_names (ctx : extraction_ctx)
| Lean -> ctx_add_decreases_proof def ctx
else ctx
in
- let ctx =
- List.fold_left register_decreases ctx (fwd.f :: fwd.loops)
- in
- let register_fun ctx f = ctx_add_fun_decl def f ctx in
+ (* We have to register the function itself, and the loops it
+ may contain (which are extracted as functions) *)
+ let funs = def.f :: def.loops in
+ (* Register the decrease clauses *)
+ let ctx = List.fold_left register_decreases ctx funs in
+ (* Register the name of the function and the loops *)
+ let register_fun ctx f = ctx_add_fun_decl f ctx in
let register_funs ctx fl = List.fold_left register_fun ctx fl in
- (* Register the names of the forward functions *)
- let ctx =
- if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx
- in
- (* Register the names of the backward functions *)
- List.fold_left
- (fun ctx { f = back; loops = loop_backs } ->
- let ctx = register_fun ctx back in
- register_funs ctx loop_backs)
- ctx backs)
+ register_funs ctx funs)
(** Simply add the global name to the context. *)
let extract_global_decl_register_names (ctx : extraction_ctx)
@@ -230,7 +196,7 @@ let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list)
let decl = FunDeclId.Map.find id ctx.trans_funs in
let err =
"Ill-formed builtin information for function "
- ^ name_to_string ctx decl.fwd.f.llbc_name
+ ^ name_to_string ctx decl.f.llbc_name
^ ": "
^ string_of_int (List.length filter)
^ " filtering arguments provided for "
@@ -460,8 +426,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
]}
*)
(match fun_id with
- | FromLlbc
- (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id, rg_id) ->
+ | FromLlbc (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id) ->
(* We have to check whether the trait method is required or provided *)
let trait_decl_id = trait_ref.trait_decl_ref.trait_decl_id in
let trait_decl =
@@ -477,7 +442,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref;
let fun_name =
ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id
- method_name rg_id ctx
+ method_name ctx
in
let add_brackets (s : string) =
if !backend = Coq then "(" ^ s ^ ")" else s
@@ -486,9 +451,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
else
(* Provided method: we see it as a regular function call, and use
the function name *)
- let fun_id =
- FromLlbc (FunId (FRegular method_id.id), lp_id, rg_id)
- in
+ let fun_id = FromLlbc (FunId (FRegular method_id.id), lp_id) in
let fun_name = ctx_get_function fun_id ctx in
F.pp_print_string fmt fun_name;
@@ -513,7 +476,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
*)
let types =
match fun_id with
- | FromLlbc (FunId (FRegular id), _, _) ->
+ | FromLlbc (FunId (FRegular id), _) ->
fun_builtin_filter_types id generics.types ctx
| _ -> Result.Ok generics.types
in
@@ -1392,11 +1355,6 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx)
let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
- let { keep_fwd; num_backs } =
- PureUtils.RegularFunIdMap.find
- (Pure.FunId (FRegular def.def_id), def.loop_id, def.back_id)
- ctx.fun_name_info
- in
let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]:" in
let comment =
let loop_comment =
@@ -1404,23 +1362,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
| None -> ""
| Some id -> " loop " ^ LoopId.to_string id ^ ":"
in
- let fwd_back_comment =
- match def.back_id with
- | None -> if !Config.return_back_funs then [] else [ "forward function" ]
- | Some id ->
- (* Check if there is only one backward function, and no forward function *)
- if (not keep_fwd) && num_backs = 1 then
- [
- "merged forward/backward function";
- "(there is a single backward function, and the forward function \
- returns ())";
- ]
- else [ "backward function " ^ T.RegionGroupId.to_string id ]
- in
- match fwd_back_comment with
- | [] -> [ comment_pre ^ loop_comment ]
- | [ s ] -> [ comment_pre ^ loop_comment ^ " " ^ s ]
- | s :: sl -> (comment_pre ^ loop_comment ^ " " ^ s) :: sl
+ [ comment_pre ^ loop_comment ]
in
extract_comment_with_span fmt comment def.meta.span
@@ -1435,9 +1377,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit =
assert (not def.is_global_decl_body);
(* Retrieve the function name *)
- let def_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let def_name = ctx_get_local_function def.def_id def.loop_id ctx in
(* Add a break before *)
if !backend <> HOL4 || not (decl_is_first_from_group kind) then
F.pp_print_break fmt 0 0;
@@ -1681,9 +1621,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
(* Retrieve the definition name *)
- let def_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let def_name = ctx_get_local_function def.def_id def.loop_id ctx in
assert (def.signature.generics.const_generics = []);
(* Add the type/const gen parameters - note that we need those bindings
only for the generation of the type (they are not top-level) *)
@@ -1870,7 +1808,6 @@ let extract_global_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(global : A.global_decl) (body : fun_decl) (interface : bool) : unit =
assert body.is_global_decl_body;
- assert (Option.is_none body.back_id);
assert (body.signature.inputs = []);
assert (body.signature.generics = empty_generic_params);
@@ -1883,9 +1820,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
let decl_name = ctx_get_global global.def_id ctx in
let body_name =
- ctx_get_function
- (FromLlbc (Pure.FunId (FRegular global.body), None, None))
- ctx
+ ctx_get_function (FromLlbc (Pure.FunId (FRegular global.body), None)) ctx
in
let decl_ty, body_ty =
@@ -2058,80 +1993,45 @@ let extract_trait_decl_method_names (ctx : extraction_ctx)
let required_methods = trait_decl.required_methods in
(* Compute the names *)
let method_names =
- (* We add one field per required forward/backward function *)
- let get_funs_for_id (id : fun_decl_id) : fun_decl list =
- let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in
- if !Config.return_back_funs then [ trans.fwd.f ]
- else List.map (fun f -> f.f) (trans.fwd :: trans.backs)
- in
match builtin_info with
| None ->
- (* We add one field per required forward/backward function *)
- let compute_item_names (item_name : string) (id : fun_decl_id) :
- string * (RegionGroupId.id option * string) list =
- let compute_fun_name (f : fun_decl) : RegionGroupId.id option * string
- =
- (* We do something special to reuse the [ctx_compute_fun_decl]
- function. TODO: make it cleaner. *)
- let llbc_name : Types.name =
- [ Types.PeIdent (item_name, Disambiguator.zero) ]
- in
- let f = { f with llbc_name } in
- let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in
- let name = ctx_compute_fun_name trans f ctx in
- (* Add a prefix if necessary *)
- let name =
- if !Config.record_fields_short_names then name
- else ctx_compute_trait_decl_name ctx trait_decl ^ "_" ^ name
- in
- (f.back_id, name)
+ (* Not a builtin function *)
+ let compute_item_name (item_name : string) (id : fun_decl_id) :
+ string * string =
+ let trans : pure_fun_translation =
+ FunDeclId.Map.find id ctx.trans_funs
+ in
+ let f = trans.f in
+ (* We do something special to reuse the [ctx_compute_fun_decl]
+ function. TODO: make it cleaner. *)
+ let llbc_name : Types.name =
+ [ Types.PeIdent (item_name, Disambiguator.zero) ]
+ in
+ let f = { f with llbc_name } in
+ let name = ctx_compute_fun_name f ctx in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx_compute_trait_decl_name ctx trait_decl ^ "_" ^ name
in
- let funs = get_funs_for_id id in
- (item_name, List.map compute_fun_name funs)
+ (item_name, name)
in
- List.map (fun (name, id) -> compute_item_names name id) required_methods
+ List.map (fun (name, id) -> compute_item_name name id) required_methods
| Some info ->
+ (* This is a builtin *)
let funs_map = StringMap.of_list info.methods in
List.map
- (fun (item_name, fun_id) ->
+ (fun (item_name, _) ->
let open ExtractBuiltin in
let info = StringMap.find item_name funs_map in
- let trans_funs = get_funs_for_id fun_id in
- let find (trans_fun : fun_decl) =
- let info =
- List.find_opt
- (fun (info : builtin_fun_info) -> info.rg = trans_fun.back_id)
- info
- in
- match info with
- | Some info -> (info.rg, info.extract_name)
- | None ->
- let err =
- "Ill-formed builtin information for trait decl \""
- ^ name_to_string ctx trait_decl.llbc_name
- ^ "\", method \"" ^ item_name
- ^ "\": could not find name for region "
- ^ Print.option_to_string Pure.show_region_group_id
- trans_fun.back_id
- in
- log#serror err;
- if !Config.fail_hard then raise (Failure err)
- else (trans_fun.back_id, "%ERROR_BUILTIN_NAME_NOT_FOUND%")
- in
- let rg_with_name_list = List.map find trans_funs in
- (item_name, rg_with_name_list))
+ let fun_name = info.extract_name in
+ (item_name, fun_name))
required_methods
in
(* Register the names *)
List.fold_left
- (fun ctx (item_name, funs) ->
- (* We add one field per required forward/backward function *)
- List.fold_left
- (fun ctx (rg, fun_name) ->
- ctx_add
- (TraitMethodId (trait_decl.def_id, item_name, rg))
- fun_name ctx)
- ctx funs)
+ (fun ctx (item_name, fun_name) ->
+ ctx_add (TraitMethodId (trait_decl.def_id, item_name)) fun_name ctx)
ctx method_names
(** Similar to {!extract_type_decl_register_names} *)
@@ -2263,46 +2163,41 @@ let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
(* Lookup the definition *)
let trans = A.FunDeclId.Map.find id ctx.trans_funs in
(* Extract the items *)
- let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
- let extract_method (f : fun_and_loops) =
- let f = f.f in
- let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in
- let ty () =
- (* Extract the generics *)
- (* We need to add the generics specific to the method, by removing those
- which actually apply to the trait decl *)
- let generics =
- let drop_trait_clauses = false in
- generic_params_drop_prefix ~drop_trait_clauses decl.generics
- f.signature.generics
- in
- (* Note that we do not filter the LLBC generic parameters.
- This is ok because:
- - we only use them to find meaningful names for the trait clauses
- - we only generate trait clauses for the clauses we find in the
- pure generics *)
- let ctx, type_params, cg_params, trait_clauses =
- ctx_add_generic_params f.llbc_name f.signature.llbc_generics generics
- ctx
- in
- let backend_uses_forall =
- match !backend with Coq | Lean -> true | FStar | HOL4 -> false
- in
- let generics_not_empty = generics <> empty_generic_params in
- let use_forall = generics_not_empty && backend_uses_forall in
- let use_arrows = generics_not_empty && not backend_uses_forall in
- let use_forall_use_sep = false in
- extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall
- ~use_forall_use_sep ~use_arrows generics type_params cg_params
- trait_clauses;
- if use_forall then F.pp_print_string fmt ",";
- (* Extract the inputs and output *)
- F.pp_print_space fmt ();
- extract_fun_inputs_output_parameters_types ctx fmt f
+ let f = trans.f in
+ let fun_name = ctx_get_trait_method decl.def_id item_name ctx in
+ let ty () =
+ (* Extract the generics *)
+ (* We need to add the generics specific to the method, by removing those
+ which actually apply to the trait decl *)
+ let generics =
+ let drop_trait_clauses = false in
+ generic_params_drop_prefix ~drop_trait_clauses decl.generics
+ f.signature.generics
+ in
+ (* Note that we do not filter the LLBC generic parameters.
+ This is ok because:
+ - we only use them to find meaningful names for the trait clauses
+ - we only generate trait clauses for the clauses we find in the
+ pure generics *)
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params f.llbc_name f.signature.llbc_generics generics ctx
in
- extract_trait_decl_item ctx fmt fun_name ty
+ let backend_uses_forall =
+ match !backend with Coq | Lean -> true | FStar | HOL4 -> false
+ in
+ let generics_not_empty = generics <> empty_generic_params in
+ let use_forall = generics_not_empty && backend_uses_forall in
+ let use_arrows = generics_not_empty && not backend_uses_forall in
+ let use_forall_use_sep = false in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall
+ ~use_forall_use_sep ~use_arrows generics type_params cg_params
+ trait_clauses;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the inputs and output *)
+ F.pp_print_space fmt ();
+ extract_fun_inputs_output_parameters_types ctx fmt f
in
- List.iter extract_method funs
+ extract_trait_decl_item ctx fmt fun_name ty
(** Extract a trait declaration *)
let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter)
@@ -2494,21 +2389,10 @@ let extract_trait_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
decl.parent_clauses;
(* The required methods *)
List.iter
- (fun (item_name, id) ->
- (* Lookup the definition *)
- let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+ (fun (item_name, _) ->
(* Extract the items *)
- let funs =
- if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs
- in
- let extract_for_method (f : fun_and_loops) =
- let f = f.f in
- let item_name =
- ctx_get_trait_method decl.def_id item_name f.back_id ctx
- in
- extract_coq_arguments_instruction ctx fmt item_name num_params
- in
- List.iter extract_for_method funs)
+ let item_name = ctx_get_trait_method decl.def_id item_name ctx in
+ extract_coq_arguments_instruction ctx fmt item_name num_params)
decl.required_methods;
(* Add a space *)
F.pp_print_space fmt ())
@@ -2531,75 +2415,71 @@ let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
(* Lookup the definition *)
let trans = A.FunDeclId.Map.find id ctx.trans_funs in
(* Extract the items *)
- let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
- let extract_method (f : fun_and_loops) =
- let f = f.f in
- let fun_name = ctx_get_trait_method trait_decl_id item_name f.back_id ctx in
- let ty () =
- (* Filter the generics if the method is a builtin *)
- let i_tys, _, _ = impl_generics in
- let impl_types, i_tys, f_tys =
- match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with
- | None -> (impl.generics.types, i_tys, f.signature.generics.types)
- | Some filter ->
- let filter_list filter ls =
- let ls = List.combine filter ls in
- List.filter_map (fun (b, ty) -> if b then Some ty else None) ls
- in
- let impl_types = impl.generics.types in
- let impl_filter =
- Collections.List.prefix (List.length impl_types) filter
- in
- let i_tys = i_tys in
- let i_filter = Collections.List.prefix (List.length i_tys) filter in
- ( filter_list impl_filter impl_types,
- filter_list i_filter i_tys,
- filter_list filter f.signature.generics.types )
- in
- let f_generics = { f.signature.generics with types = f_tys } in
- (* Extract the generics - we need to quantify over the generics which
- are specific to the method, and call it will all the generics
- (trait impl + method generics) *)
- let f_generics =
- let drop_trait_clauses = true in
- generic_params_drop_prefix ~drop_trait_clauses
- { impl.generics with types = impl_types }
- f_generics
- in
- (* Register and print the quantified generics.
-
- Note that we do not filter the LLBC generic parameters.
- This is ok because:
- - we only use them to find meaningful names for the trait clauses
- - we only generate trait clauses for the clauses we find in the
- pure generics *)
- let ctx, f_tys, f_cgs, f_tcs =
- ctx_add_generic_params f.llbc_name f.signature.llbc_generics f_generics
- ctx
- in
- let use_forall = f_generics <> empty_generic_params in
- extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics
- f_tys f_cgs f_tcs;
- if use_forall then F.pp_print_string fmt ",";
- (* Extract the function call *)
- F.pp_print_space fmt ();
- let fun_name = ctx_get_local_function f.def_id None f.back_id ctx in
- F.pp_print_string fmt fun_name;
- let all_generics =
- let _, i_cgs, i_tcs = impl_generics in
- List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ]
- in
-
- (* Filter the generics if the function is builtin *)
- List.iter
- (fun p ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt p)
- all_generics
+ let f = trans.f in
+ let fun_name = ctx_get_trait_method trait_decl_id item_name ctx in
+ let ty () =
+ (* Filter the generics if the method is a builtin *)
+ let i_tys, _, _ = impl_generics in
+ let impl_types, i_tys, f_tys =
+ match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with
+ | None -> (impl.generics.types, i_tys, f.signature.generics.types)
+ | Some filter ->
+ let filter_list filter ls =
+ let ls = List.combine filter ls in
+ List.filter_map (fun (b, ty) -> if b then Some ty else None) ls
+ in
+ let impl_types = impl.generics.types in
+ let impl_filter =
+ Collections.List.prefix (List.length impl_types) filter
+ in
+ let i_tys = i_tys in
+ let i_filter = Collections.List.prefix (List.length i_tys) filter in
+ ( filter_list impl_filter impl_types,
+ filter_list i_filter i_tys,
+ filter_list filter f.signature.generics.types )
+ in
+ let f_generics = { f.signature.generics with types = f_tys } in
+ (* Extract the generics - we need to quantify over the generics which
+ are specific to the method, and call it will all the generics
+ (trait impl + method generics) *)
+ let f_generics =
+ let drop_trait_clauses = true in
+ generic_params_drop_prefix ~drop_trait_clauses
+ { impl.generics with types = impl_types }
+ f_generics
+ in
+ (* Register and print the quantified generics.
+
+ Note that we do not filter the LLBC generic parameters.
+ This is ok because:
+ - we only use them to find meaningful names for the trait clauses
+ - we only generate trait clauses for the clauses we find in the
+ pure generics *)
+ let ctx, f_tys, f_cgs, f_tcs =
+ ctx_add_generic_params f.llbc_name f.signature.llbc_generics f_generics
+ ctx
+ in
+ let use_forall = f_generics <> empty_generic_params in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics
+ f_tys f_cgs f_tcs;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the function call *)
+ F.pp_print_space fmt ();
+ let fun_name = ctx_get_local_function f.def_id None ctx in
+ F.pp_print_string fmt fun_name;
+ let all_generics =
+ let _, i_cgs, i_tcs = impl_generics in
+ List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ]
in
- extract_trait_impl_item ctx fmt fun_name ty
+
+ (* Filter the generics if the function is builtin *)
+ List.iter
+ (fun p ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt p)
+ all_generics
in
- List.iter extract_method funs
+ extract_trait_impl_item ctx fmt fun_name ty
(** Extract a trait implementation *)
let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
@@ -2766,8 +2646,6 @@ let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
*)
let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
- (* We only insert unit tests for forward functions *)
- assert (def.back_id = None);
(* Check if this is a unit function *)
let sg = def.signature in
if
@@ -2791,9 +2669,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "assert_norm";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
@@ -2807,9 +2683,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "Check";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
@@ -2820,9 +2694,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt "#assert";
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
- let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
@@ -2835,9 +2707,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
| HOL4 ->
F.pp_print_string fmt "val _ = assert_return (";
F.pp_print_string fmt "“";
- let fun_name =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx
- in
+ let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in
F.pp_print_string fmt fun_name;
if sg.inputs <> [] then (
F.pp_print_space fmt ();
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 04686705..591e8aab 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -167,11 +167,7 @@ type id =
| TraitImplId of TraitImplId.id
| LocalTraitClauseId of TraitClauseId.id
| TraitDeclConstructorId of TraitDeclId.id
- | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option
- (** Something peculiar with trait methods: because we have to take into
- account forward/backward functions, we may need to generate fields
- items per method.
- *)
+ | TraitMethodId of TraitDeclId.id * string
| TraitItemId of TraitDeclId.id * string
(** A trait associated item which is not a method *)
| TraitParentClauseId of TraitDeclId.id * TraitClauseId.id
@@ -353,8 +349,6 @@ let basename_to_unique (names_set : StringSet.t)
in
if StringSet.mem basename names_set then gen 1 else basename
-type fun_name_info = { keep_fwd : bool; num_backs : int }
-
type names_maps = {
names_map : names_map;
(** The map for id to names, where we forbid name collisions
@@ -384,7 +378,7 @@ let allow_collisions (id : id) : bool =
| FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitItemId _
| TraitMethodId _ ->
!Config.record_fields_short_names
- | FunId (Pure _ | FromLlbc (FunId (FAssumed _), _, _)) ->
+ | FunId (Pure _ | FromLlbc (FunId (FAssumed _), _)) ->
(* We map several assumed functions to the same id *)
true
| _ -> false
@@ -549,15 +543,6 @@ type extraction_ctx = {
-- makes the if then else dependent
]}
*)
- fun_name_info : fun_name_info PureUtils.RegularFunIdMap.t;
- (** Information used to filter and name functions - we use it
- to print comments in the generated code, to help link
- the generated code to the original code (information such
- as: "this function is the backward function of ...", or
- "this function is the merged forward/backward function of ..."
- in case a Rust function only has one backward translation
- and we filter the forward function because it returns unit.
- *)
trait_decl_id : trait_decl_id option;
(** If we are extracting a trait declaration, identifies it *)
is_provided_method : bool;
@@ -668,14 +653,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
^ TraitClauseId.to_string clause_id
| TraitItemId (id, name) ->
"trait_item_id: " ^ trait_decl_id_to_string id ^ ", type name: " ^ name
- | TraitMethodId (trait_decl_id, fun_name, rg_id) ->
- let fwd_back_kind =
- match rg_id with
- | None -> "forward"
- | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id
- in
- trait_decl_id_to_string trait_decl_id
- ^ ", method name (" ^ fwd_back_kind ^ "): " ^ fun_name
+ | TraitMethodId (trait_decl_id, fun_name) ->
+ trait_decl_id_to_string trait_decl_id ^ ", method name: " ^ fun_name
| TraitSelfClauseId -> "trait_self_clause"
let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx =
@@ -694,8 +673,8 @@ let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string =
ctx_get (FunId id) ctx
let ctx_get_local_function (id : A.FunDeclId.id) (lp : LoopId.id option)
- (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string =
- ctx_get_function (FromLlbc (FunId (FRegular id), lp, rg)) ctx
+ (ctx : extraction_ctx) : string =
+ ctx_get_function (FromLlbc (FunId (FRegular id), lp)) ctx
let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string =
assert (id <> TTuple);
@@ -733,8 +712,8 @@ let ctx_get_trait_type (id : trait_decl_id) (item_name : string)
ctx_get_trait_item id item_name ctx
let ctx_get_trait_method (id : trait_decl_id) (item_name : string)
- (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string =
- ctx_get (TraitMethodId (id, item_name, rg_id)) ctx
+ (ctx : extraction_ctx) : string =
+ ctx_get (TraitMethodId (id, item_name)) ctx
let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id)
(ctx : extraction_ctx) : string =
@@ -1164,8 +1143,7 @@ let initialize_names_maps () : names_maps =
in
let assumed_functions =
List.map
- (fun (fid, rg, name) ->
- (FromLlbc (Pure.FunId (FAssumed fid), None, rg), name))
+ (fun (fid, name) -> (FromLlbc (Pure.FunId (FAssumed fid), None), name))
init.assumed_llbc_functions
@ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions
in
@@ -1408,61 +1386,12 @@ let default_fun_loop_suffix (num_loops : int) (loop_id : LoopId.id option) :
If this function admits only one loop, we omit it. *)
if num_loops = 1 then "_loop" else "_loop" ^ LoopId.to_string loop_id
-(** A helper function: generates a function suffix from a region group
- information.
+(** A helper function: generates a function suffix.
TODO: move all those helpers.
*)
-let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option)
- (num_region_groups : int) (rg : region_group_info option)
- ((keep_fwd, num_backs) : bool * int) : string =
- let lp_suff = default_fun_loop_suffix num_loops loop_id in
-
- (* There are several cases:
- - [rg] is [Some]: this is a forward function:
- - we add "_fwd"
- - [rg] is [None]: this is a backward function:
- - this function has one extracted backward function:
- - if the forward function has been filtered, we add nothing:
- the forward function is useless, so the unique backward function
- takes its place, in a way (in effect, we "merge" the forward
- and the backward functions).
- - otherwise we add "_back"
- - this function has several backward functions: we add "_back" and an
- additional suffix to identify the precise backward function
- Note that we always add a suffix (in case there are no region groups,
- we could not add the "_fwd" suffix) to prevent name clashes between
- definitions (in particular between type and function definitions).
- *)
- let rg_suff =
- (* TODO: make all the backends match what is done for Lean *)
- match rg with
- | None ->
- if
- (* In order to avoid name conflicts:
- * - if the forward is eliminated, we add the suffix "_fwd" (it won't be used)
- * - otherwise, no suffix (because the backward functions will have a suffix)
- *)
- num_backs = 1 && not keep_fwd
- then "_fwd"
- else ""
- | Some rg ->
- assert (num_region_groups > 0 && num_backs > 0);
- if num_backs = 1 then
- (* Exactly one backward function *)
- if not keep_fwd then "" else "_back"
- else if
- (* Several region groups/backward functions:
- - if all the regions in the group have names, we use those names
- - otherwise we use an index
- *)
- List.for_all Option.is_some rg.region_names
- then
- (* Concatenate the region names *)
- "_back" ^ String.concat "" (List.map Option.get rg.region_names)
- else (* Use the region index *)
- "_back" ^ RegionGroupId.to_string rg.id
- in
- lp_suff ^ rg_suff
+let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) : string =
+ (* We only generate a suffix for the functions we generate from the loops *)
+ default_fun_loop_suffix num_loops loop_id
(** Compute the name of a regular (non-assumed) function.
@@ -1472,24 +1401,13 @@ let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option)
indices to derive unique names for the loops for instance - if there is
exactly one loop, we don't need to use indices)
- loop id (if pertinent)
- - number of region groups
- - region group information in case of a backward function
- ([None] if forward function)
- - pair:
- - do we generate the forward function (it may have been filtered)?
- - the number of *extracted backward functions* (same comment as for
- the number of loops)
- The number of extracted backward functions if not necessarily
- equal to the number of region groups, because we may have
- filtered some of them.
TODO: use the fun id for the assumed functions.
*)
let ctx_compute_fun_name (ctx : extraction_ctx) (fname : llbc_name)
- (num_loops : int) (loop_id : LoopId.id option) (num_rgs : int)
- (rg : region_group_info option) (filter_info : bool * int) : string =
+ (num_loops : int) (loop_id : LoopId.id option) : string =
let fname = ctx_compute_fun_name_no_suffix ctx fname in
(* Compute the suffix *)
- let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in
+ let suffix = default_fun_suffix num_loops loop_id in
(* Concatenate *)
fname ^ suffix
@@ -1963,61 +1881,26 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
| None ->
(* Not the case: "standard" registration *)
let name = ctx_compute_global_name ctx def.name in
- let body = FunId (FromLlbc (FunId (FRegular def.body), None, None)) in
+ let body = FunId (FromLlbc (FunId (FRegular def.body), None)) in
let ctx = ctx_add decl (name ^ "_c") ctx in
let ctx = ctx_add body (name ^ "_body") ctx in
ctx
-let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl)
- (ctx : extraction_ctx) : string =
- (* Lookup the LLBC def to compute the region group information *)
- let def_id = def.def_id in
- let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in
- let sg = llbc_def.signature in
- let regions_hierarchy =
- LlbcAstUtils.FunIdMap.find (FRegular def_id)
- ctx.trans_ctx.fun_ctx.regions_hierarchies
- in
- let num_rgs = List.length regions_hierarchy in
- let { keep_fwd; fwd = _; backs } = trans_group in
- let num_backs = List.length backs in
- let rg_info =
- match def.back_id with
- | None -> None
- | Some rg_id ->
- let rg = T.RegionGroupId.nth regions_hierarchy rg_id in
- let region_names =
- List.map
- (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
- rg.regions
- in
- Some { id = rg_id; region_names }
- in
+let ctx_compute_fun_name (def : fun_decl) (ctx : extraction_ctx) : string =
(* Add the function name *)
- ctx_compute_fun_name ctx def.llbc_name def.num_loops def.loop_id num_rgs
- rg_info (keep_fwd, num_backs)
+ ctx_compute_fun_name ctx def.llbc_name def.num_loops def.loop_id
(* TODO: move to Extract *)
-let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl)
- (ctx : extraction_ctx) : extraction_ctx =
+let ctx_add_fun_decl (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx =
(* Sanity check: the function should not be a global body - those are handled
* separately *)
assert (not def.is_global_decl_body);
(* Lookup the LLBC def to compute the region group information *)
let def_id = def.def_id in
- let { keep_fwd; fwd = _; backs } = trans_group in
- let num_backs = List.length backs in
(* Add the function name *)
- let def_name = ctx_compute_fun_name trans_group def ctx in
- let fun_id = (Pure.FunId (FRegular def_id), def.loop_id, def.back_id) in
- let ctx = ctx_add (FunId (FromLlbc fun_id)) def_name ctx in
- (* Add the name info *)
- {
- ctx with
- fun_name_info =
- PureUtils.RegularFunIdMap.add fun_id { keep_fwd; num_backs }
- ctx.fun_name_info;
- }
+ let def_name = ctx_compute_fun_name def ctx in
+ let fun_id = (Pure.FunId (FRegular def_id), def.loop_id) in
+ ctx_add (FunId (FromLlbc fun_id)) def_name ctx
let ctx_compute_type_decl_name (ctx : extraction_ctx) (def : type_decl) : string
=
diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml
index 3ea5655a..88de31fe 100644
--- a/compiler/ExtractBuiltin.ml
+++ b/compiler/ExtractBuiltin.ml
@@ -221,12 +221,11 @@ type builtin_fun_info = { extract_name : string } [@@deriving show]
parameters. For instance, in the case of the `Vec` functions, there is
a type parameter for the allocator to use, which we want to filter.
*)
-let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list
- =
+let builtin_funs () : (pattern * bool list option * builtin_fun_info) list =
(* Small utility *)
let mk_fun (rust_name : string) (extract_name : string option)
(filter : bool list option) :
- pattern * bool list option * builtin_fun_info list =
+ pattern * bool list option * builtin_fun_info =
let rust_name =
try parse_pattern rust_name
with Failure _ ->
@@ -238,7 +237,7 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list
| Some name -> split_on_separator name
in
let basename = flatten_name extract_name in
- let f = [ { extract_name = basename } ] in
+ let f = { extract_name = basename } in
(rust_name, filter, f)
in
[
@@ -377,7 +376,7 @@ type builtin_trait_decl_info = {
- a Rust name
- an extraction name
- a list of clauses *)
- methods : (string * builtin_fun_info list) list;
+ methods : (string * builtin_fun_info) list;
}
[@@deriving show]
@@ -418,7 +417,7 @@ let builtin_trait_decls_info () =
if !record_fields_short_names then item_name
else extract_name ^ "_" ^ item_name
in
- let fwd = [ { extract_name = basename } ] in
+ let fwd = { extract_name = basename } in
(item_name, fwd)
in
List.map mk_method methods
diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml
index a3dbf3cc..05b71b9f 100644
--- a/compiler/ExtractTypes.ml
+++ b/compiler/ExtractTypes.ml
@@ -272,7 +272,7 @@ let start_fun_decl_group (ctx : extraction_ctx) (fmt : F.formatter)
if is_single_opaque_fun_decl_group dg then ()
else
let compute_fun_def_name (def : Pure.fun_decl) : string =
- ctx_get_local_function def.def_id def.loop_id def.back_id ctx ^ "_def"
+ ctx_get_local_function def.def_id def.loop_id ctx ^ "_def"
in
let names = List.map compute_fun_def_name dg in
(* Add a break before *)
diff --git a/compiler/Main.ml b/compiler/Main.ml
index 664ec067..3f5e62ad 100644
--- a/compiler/Main.ml
+++ b/compiler/Main.ml
@@ -72,12 +72,6 @@ let () =
Arg.Symbol (backend_names, set_backend),
" Specify the target backend" );
("-dest", Arg.Set_string dest_dir, " Specify the output directory");
- ( "-no-filter-useless-calls",
- Arg.Clear filter_useless_monadic_calls,
- " Do not filter the useless function calls" );
- ( "-no-filter-useless-funs",
- Arg.Clear filter_useless_functions,
- " Do not filter the useless forward/backward functions" );
( "-test-units",
Arg.Set test_unit_functions,
" Test the unit functions with the concrete (i.e., not symbolic) \
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 66475d02..21ca7f08 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -462,21 +462,13 @@ let inst_fun_sig_to_string (env : fmt_env) (sg : inst_fun_sig) : string =
let all_types = List.append inputs [ output ] in
String.concat " -> " all_types
-let fun_suffix (lp_id : LoopId.id option) (rg_id : T.RegionGroupId.id option) :
- string =
+let fun_suffix (lp_id : LoopId.id option) : string =
let lp_suff =
match lp_id with
| None -> ""
| Some lp_id -> "^loop" ^ LoopId.to_string lp_id
in
-
- let rg_suff =
- match rg_id with
- | None -> ""
- | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id
- in
-
- lp_suff ^ rg_suff
+ lp_suff
let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string =
match fid with
@@ -505,7 +497,7 @@ let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string =
let regular_fun_id_to_string (env : fmt_env) (fun_id : fun_id) : string =
match fun_id with
- | FromLlbc (fid, lp_id, rg_id) ->
+ | FromLlbc (fid, lp_id) ->
let f =
match fid with
| FunId (FRegular fid) -> fun_decl_id_to_string env fid
@@ -513,7 +505,7 @@ let regular_fun_id_to_string (env : fmt_env) (fun_id : fun_id) : string =
| TraitMethod (trait_ref, method_name, _) ->
trait_ref_to_string env true trait_ref ^ "." ^ method_name
in
- f ^ fun_suffix lp_id rg_id
+ f ^ fun_suffix lp_id
| Pure fid -> pure_assumed_fun_id_to_string fid
let unop_to_string (unop : unop) : string =
@@ -746,7 +738,7 @@ and emeta_to_string (env : fmt_env) (meta : emeta) : string =
let fun_decl_to_string (env : fmt_env) (def : fun_decl) : string =
let env = { env with generics = def.signature.generics } in
- let name = def.name ^ fun_suffix def.loop_id def.back_id in
+ let name = def.name ^ fun_suffix def.loop_id in
let signature = fun_sig_to_string env def.signature in
match def.body with
| None -> "val " ^ name ^ " :\n " ^ signature
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index a879ba37..dd7a4acf 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -560,8 +560,7 @@ type fun_id_or_trait_method_ref =
[@@deriving show, ord]
(** A function id for a non-assumed function *)
-type regular_fun_id =
- fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option
+type regular_fun_id = fun_id_or_trait_method_ref * LoopId.id option
[@@deriving show, ord]
(** A function identifier *)
@@ -1078,7 +1077,6 @@ type fun_decl = {
*)
loop_id : LoopId.id option;
(** [Some] if this definition was generated for a loop *)
- back_id : RegionGroupId.id option;
llbc_name : llbc_name; (** The original LLBC name. *)
name : string;
(** We use the name only for printing purposes (for debugging):
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index ec64df21..04bc90d7 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -925,156 +925,9 @@ let inline_useless_var_reassignments (ctx : trans_ctx) ~(inline_named : bool)
in
{ def with body = Some body }
-(** For the cases where we split the forward/backward functions.
-
- Given a forward or backward function call, is there, for every execution
- path, a child backward function called later with exactly the same input
- list prefix. We use this to filter useless function calls: if there are
- such child calls, we can remove this one (in case its outputs are not
- used).
- We do this check because we can't simply remove function calls whose
- outputs are not used, as they might fail. However, if a function fails,
- its children backward functions then fail on the same inputs (ignoring
- the additional inputs those receive).
-
- For instance, if we have:
- {[
- fn f<'a>(x : &'a mut T);
- ]}
-
- We often have things like this in the synthesized code:
- {[
- _ <-- f@fwd x;
- ...
- nx <-- f@back'a x y;
- ...
- ]}
-
- If [f@back'a x y] fails, then necessarily [f@fwd x] also fails.
- In this situation, we can remove the call [f@fwd x].
- *)
-let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option)
- (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args)
- (args0 : texpression list) (e : texpression) : bool =
- let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args)
- (args1 : texpression list) : bool =
- (* Check the fun_ids, to see if call1's function is a child of call0's function *)
- match fun_id1 with
- | Fun (FromLlbc (id1, lp_id1, rg_id1)) ->
- (* Both are "regular" calls: check if they come from the same rust function *)
- if id0 = id1 && lp_id0 = lp_id1 then
- (* Same rust functions: check the regions hierarchy *)
- let call1_is_child =
- match (rg_id0, rg_id1) with
- | None, _ ->
- (* The function used in call0 is the forward function: the one
- * used in call1 is necessarily a child *)
- true
- | Some _, None ->
- (* Opposite of previous case *)
- false
- | Some rg_id0, Some rg_id1 ->
- if rg_id0 = rg_id1 then true
- else
- (* We need to use the regions hierarchy *)
- let regions_hierarchy =
- let id0 =
- match id0 with
- | FunId fun_id -> fun_id
- | TraitMethod (_, _, fun_decl_id) -> FRegular fun_decl_id
- in
- LlbcAstUtils.FunIdMap.find id0
- ctx.fun_ctx.regions_hierarchies
- in
- (* Compute the set of ancestors of the function in call1 *)
- let call1_ancestors =
- LlbcAstUtils.list_ancestor_region_groups regions_hierarchy
- rg_id1
- in
- (* Check if the function used in call0 is inside *)
- T.RegionGroupId.Set.mem rg_id0 call1_ancestors
- in
- (* If call1 is a child, then we need to check if the input arguments
- * used in call0 are a prefix of the input arguments used in call1
- * (note call1 being a child, it will likely consume strictly more
- * given back values).
- * *)
- if call1_is_child then
- let call1_args =
- Collections.List.prefix (List.length args0) args1
- in
- let args = List.combine args0 call1_args in
- (* Note that the input values are expressions, *which may contain
- * meta-values* (which we need to ignore). *)
- let input_eq (v0, v1) =
- PureUtils.remove_meta v0 = PureUtils.remove_meta v1
- in
- (* Compare the generics and the prefix of the input arguments *)
- generics0 = generics1 && List.for_all input_eq args
- else (* Not a child *)
- false
- else (* Not the same function *)
- false
- | _ -> false
- in
-
- let visitor =
- object (self)
- inherit [_] reduce_expression
- method zero _ = false
- method plus b0 b1 _ = b0 () && b1 ()
-
- method! visit_texpression env e =
- match e.e with
- | Var _ | CVar _ | Const _ -> fun _ -> false
- | StructUpdate _ ->
- (* There shouldn't be monadic calls in structure updates - also
- note that by returning [false] we are conservative: we might
- *prevent* possible optimisations (i.e., filtering some function
- calls), which is sound. *)
- fun _ -> false
- | Let (_, _, re, e) -> (
- match opt_destruct_function_call re with
- | None -> fun () -> self#visit_texpression env e ()
- | Some (func1, generics1, args1) ->
- let call_is_child = check_call func1 generics1 args1 in
- if call_is_child then fun () -> true
- else fun () -> self#visit_texpression env e ())
- | Lambda (_, e) -> self#visit_texpression env e
- | App _ -> (
- fun () ->
- match opt_destruct_function_call e with
- | Some (func1, tys1, args1) -> check_call func1 tys1 args1
- | None -> false)
- | Qualif _ ->
- (* Note that this case includes functions without arguments *)
- fun () -> false
- | Meta (_, e) -> self#visit_texpression env e
- | Loop loop ->
- (* We only visit the *function end* *)
- self#visit_texpression env loop.fun_end
- | Switch (_, body) -> self#visit_switch_body env body
-
- method! visit_switch_body env body =
- match body with
- | If (e1, e2) ->
- fun () ->
- self#visit_texpression env e1 ()
- && self#visit_texpression env e2 ()
- | Match branches ->
- fun () ->
- List.for_all
- (fun br -> self#visit_texpression env br.branch ())
- branches
- end
- in
- visitor#visit_texpression () e ()
-
(** Filter the useless assignments (removes the useless variables, filters
the function calls) *)
-let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
- (def : fun_decl) : fun_decl =
+let filter_useless (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* We first need a transformation on *left-values*, which filters the useless
* variables and tells us whether the value contains any variable which has
* not been replaced by [_] (in which case we need to keep the assignment,
@@ -1166,30 +1019,8 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
if not monadic then
(* Not a monadic let-binding: simple case *)
(e.e, fun _ -> used)
- else
- (* Monadic let-binding: trickier.
- * We can filter if the right-expression is a function call,
- * under some conditions. *)
- match (filter_monadic_calls, opt_destruct_function_call re) with
- | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) ->
- (* If we split the forward/backward functions.
-
- We need to check if there is a child call - see
- the comments for:
- [expression_contains_child_call_in_all_paths] *)
- if not !Config.return_back_funs then
- let has_child_call =
- expression_contains_child_call_in_all_paths ctx fid
- lp_id rg_id tys args e
- in
- if has_child_call then (* Filter *)
- (e.e, fun _ -> used)
- else (* No child call: don't filter *)
- dont_filter ()
- else dont_filter ()
- | _ ->
- (* Not an LLBC function call or not allowed to filter: we can't filter *)
- dont_filter ()
+ else (* Monadic let-binding: can't filter *)
+ dont_filter ()
else (* There are used variables: don't filter *)
dont_filter ()
| Loop loop ->
@@ -1442,22 +1273,6 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let body = { body with body = body_exp } in
{ def with body = Some body }
-(** Return [None] if the function is a backward function with no outputs (so
- that we eliminate the definition which is useless).
-
- Note that the calls to such functions are filtered when translating from
- symbolic to pure. Here, we remove the definitions altogether, because they
- are now useless
- *)
-let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
- if
- !Config.filter_useless_functions
- && Option.is_some def.back_id
- && def.signature.output = mk_result_ty mk_unit_ty
- || def.signature.output = mk_unit_ty
- then None
- else Some def
-
(** Retrieve the loop definitions from the function definition.
{!SymbolicToPure} generates an AST in which the loop bodies are part of
@@ -1530,14 +1345,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
info.num_inputs_with_fuel_no_state
info.num_inputs_with_fuel_with_state
in
- let back_inputs =
- if !Config.return_back_funs then []
- else
- snd
- (Collections.List.split_at fun_sig.inputs
- info.num_inputs_with_fuel_with_state)
- in
- List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ]
+ List.concat [ fuel; fwd_inputs; fwd_state ]
in
let output = loop.output_ty in
@@ -1618,7 +1426,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
kind = def.kind;
num_loops;
loop_id = Some loop.loop_id;
- back_id = def.back_id;
llbc_name = def.llbc_name;
name = def.name;
signature = loop_sig;
@@ -1640,35 +1447,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
let loops = List.map snd (LoopId.Map.bindings !loops) in
(def, loops)
-(** Return [false] if the forward function is useless and should be filtered.
-
- - a forward function with no output (comes from a Rust function with
- unit return type)
- - the function has mutable borrows as inputs (which is materialized
- by the fact we generated backward functions which were not filtered).
-
- In such situation, every call to the Rust function will be translated to:
- - a call to the forward function which returns nothing
- - calls to the backward functions
- As a failing backward function implies the forward function also fails,
- we can filter the calls to the forward function, which thus becomes
- useless.
- In such situation, we can remove the forward function definition
- altogether.
- *)
-let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
- (* The question of filtering the forward functions arises only if we split
- the forward/backward functions *)
- if !Config.return_back_funs then true
- else if
- (* Note that at this point, the output types are no longer seen as tuples:
- * they should be lists of length 1. *)
- !Config.filter_useless_functions
- && fwd.f.signature.output = mk_result_ty mk_unit_ty
- && backs <> []
- then false
- else true
-
(** Convert the unit variables to [()] if they are used as right-values or
[_] if they are used as left values in patterns. *)
let unit_vars_to_unit (def : fun_decl) : fun_decl =
@@ -1724,19 +1502,17 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
* could have: [box_new f x])
* *)
match fun_id with
- | Fun (FromLlbc (FunId (FAssumed aid), _lp_id, rg_id)) -> (
- match (aid, rg_id) with
- | BoxNew, _ ->
- assert (rg_id = None);
+ | Fun (FromLlbc (FunId (FAssumed aid), _lp_id)) -> (
+ match aid with
+ | BoxNew ->
let arg, args = Collections.List.pop args in
mk_apps arg args
- | BoxFree, _ ->
+ | BoxFree ->
assert (args = []);
mk_unit_rvalue
- | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared
- | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
- | ArrayRepeat ),
- _ ) ->
+ | SliceIndexShared | SliceIndexMut | ArrayIndexShared
+ | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
+ | ArrayRepeat ->
super#visit_texpression env e)
| _ -> super#visit_texpression env e)
| _ -> super#visit_texpression env e
@@ -1989,7 +1765,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
(* Filter the useless variables, assignments, function calls, etc. *)
- let def = filter_useless !Config.filter_useless_monadic_calls ctx def in
+ let def = filter_useless ctx def in
log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
(* Simplify the lets immediately followed by a return.
@@ -2130,16 +1906,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
*)
let all_decls =
List.concat
- (List.concat
- (List.concat
- (List.map
- (fun { fwd; backs; _ } ->
- [ fwd.f :: fwd.loops ]
- :: List.map
- (fun { f = back; loops = loops_back } ->
- [ back :: loops_back ])
- backs)
- transl)))
+ (List.concat (List.map (fun { f; loops } -> [ f :: loops ]) transl))
in
let subgroups = ReorderDecls.group_reorder_fun_decls all_decls in
@@ -2207,7 +1974,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) ->
+ | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id'))) ->
if (fun_id', loop_id') = fun_id then (
(* For each argument, check if it is exactly the original
input parameter. Note that there shouldn't be partial
@@ -2357,8 +2124,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _)))
- -> (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id))) -> (
match
FunLoopIdMap.find_opt (fun_id, loop_id) !used_map
with
@@ -2400,13 +2166,8 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
in
let transl =
List.map
- (fun trans ->
- let filter_fun_and_loops f =
- { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }
- in
- let fwd = filter_fun_and_loops trans.fwd in
- let backs = List.map filter_fun_and_loops trans.backs in
- { trans with fwd; backs })
+ (fun f ->
+ { f = filter_in_one f.f; loops = List.map filter_in_one f.loops })
transl
in
@@ -2420,18 +2181,11 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
it thus returns the pair: (function def, loop defs). See {!decompose_loops}
for more information.
- Will return [None] if the function is a backward function with no outputs.
-
[ctx]: used only for printing.
*)
-let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
- fun_and_loops option =
+let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_and_loops =
(* Debug *)
- log#ldebug
- (lazy
- ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string def.back_id
- ^ ")"));
+ log#ldebug (lazy ("PureMicroPasses.apply_passes_to_def: " ^ def.name));
log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
@@ -2451,29 +2205,13 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
let def = remove_meta def in
log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
- (* Remove the backward functions with no outputs.
+ (* Extract the loop definitions by removing the {!Loop} node *)
+ let def, loops = decompose_loops ctx def in
- Note that the *calls* to those functions should already have been removed,
- when translating from symbolic to pure. Here, we remove the definitions
- altogether, because they are now useless *)
- let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in
- let opt_def = filter_if_backward_with_no_outputs def in
-
- match opt_def with
- | None ->
- log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n"));
- None
- | Some def ->
- log#ldebug
- (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n"));
-
- (* Extract the loop definitions by removing the {!Loop} node *)
- let def, loops = decompose_loops ctx def in
-
- (* Apply the remaining passes *)
- let f = apply_end_passes_to_def ctx def in
- let loops = List.map (apply_end_passes_to_def ctx) loops in
- Some { f; loops }
+ (* Apply the remaining passes *)
+ let f = apply_end_passes_to_def ctx def in
+ let loops = List.map (apply_end_passes_to_def ctx) loops in
+ { f; loops }
(** Apply the micro-passes to a list of forward/backward translations.
@@ -2489,18 +2227,11 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
but convenient.
*)
let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
- (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list =
- let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation =
- (* Apply the passes to the individual functions *)
- let fwd, backs = trans in
- let fwd = Option.get (apply_passes_to_def ctx fwd) in
- let backs = List.filter_map (apply_passes_to_def ctx) backs in
- (* Compute whether we need to filter the forward function or not *)
- let keep_fwd = keep_forward fwd backs in
- { keep_fwd; fwd; backs }
- in
-
- let transl = List.map apply_to_one transl in
+ (transl : fun_decl list) : pure_fun_translation list =
+ (* Apply the micro-passes *)
+ let transl = List.map (apply_passes_to_def ctx) transl in
- (* Filter the useless inputs in the loop functions *)
+ (* Filter the useless inputs in the loop functions (loops are initially
+ parameterized by *all* the symbolic values in the context, because
+ they may access any of them). *)
filter_loop_inputs transl
diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml
index 53c94ff4..f5443e03 100644
--- a/compiler/ReorderDecls.ml
+++ b/compiler/ReorderDecls.ml
@@ -5,11 +5,7 @@ open Pure
(** The local logger *)
let log = Logging.reorder_decls_log
-type fun_id = {
- def_id : FunDeclId.id;
- lp_id : LoopId.id option;
- rg_id : T.RegionGroupId.id option;
-}
+type fun_id = { def_id : FunDeclId.id; lp_id : LoopId.id option }
[@@deriving show, ord]
module FunIdOrderedType : OrderedType with type t = fun_id = struct
@@ -43,11 +39,11 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t =
| FunOrOp (Fun fid) -> (
match fid with
| Pure _ -> ()
- | FromLlbc (fid, lp_id, rg_id) -> (
+ | FromLlbc (fid, lp_id) -> (
match fid with
| FunId (FAssumed _) -> ()
| TraitMethod (_, _, fid) | FunId (FRegular fid) ->
- let id = { def_id = fid; lp_id; rg_id } in
+ let id = { def_id = fid; lp_id } in
ids := FunIdSet.add id !ids))
end
in
@@ -71,7 +67,7 @@ let group_reorder_fun_decls (decls : fun_decl list) :
(bool * fun_decl list) list =
let module IntMap = MakeMap (OrderedInt) in
let get_fun_id (decl : fun_decl) : fun_id =
- { def_id = decl.def_id; lp_id = decl.loop_id; rg_id = decl.back_id }
+ { def_id = decl.def_id; lp_id = decl.loop_id }
in
(* Compute the list/set of identifiers *)
let idl = List.map get_fun_id decls in
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 859d6f17..2db5f66c 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -2035,7 +2035,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| S.Fun (fid, call_id) ->
(* Regular function call *)
let fid_t = translate_fun_id_or_trait_method_ref ctx fid in
- let func = Fun (FromLlbc (fid_t, None, None)) in
+ let func = Fun (FromLlbc (fid_t, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info = get_fun_effect_info ctx fid None None in
@@ -2539,8 +2539,6 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
- let generics = loop_info.generics in
- let fwd_inputs = Option.get loop_info.forward_inputs in
(* Retrieve the additional backward inputs. Note that those are actually
the backward inputs of the function we are synthesizing (and that we
need to *transmit* to the loop backward function): they are not the
@@ -2582,10 +2580,6 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output.ty else output.ty
- in
(* Create the expression for the function:
- it is either a call to a top-level function, if we split the
forward/backward functions
@@ -3218,7 +3212,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let out_pat = mk_simpl_tuple_pattern out_pats in
let loop_call =
- let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in
+ let fun_id = Fun (FromLlbc (FunId fid, Some loop_id)) in
let func = { id = FunOrOp fun_id; generics = loop_info.generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
@@ -3567,9 +3561,6 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let name = name_to_string ctx llbc_name in
(* Translate the signature *)
let signature = translate_fun_sig_from_decomposed ctx.sg in
- let regions_hierarchy =
- FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies
- in
(* Translate the body, if there is *)
let body =
match body with
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 55a94302..c12de045 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -1,5 +1,4 @@
open Interpreter
-open Expressions
open Types
open Values
open LlbcAst
@@ -49,8 +48,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
log#ldebug
(lazy ("translate_function_to_pure: " ^ name_to_string trans_ctx fdef.name));
- let def_id = fdef.def_id in
-
(* Compute the symbolic ASTs, if the function is transparent *)
let symbolic_trans = translate_function_to_symbolics trans_ctx fdef in
@@ -124,20 +121,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef
in
- let regions_hierarchy =
- LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_ctx.regions_hierarchies
- in
-
- let var_counter, back_state_vars =
- if !Config.return_back_funs then (var_counter, [])
- else
- List.fold_left_map
- (fun var_counter (region_vars : region_var_group) ->
- let gid = region_vars.id in
- let var, var_counter = Pure.VarId.fresh var_counter in
- (var_counter, (gid, var)))
- var_counter regions_hierarchy
- in
+ let var_counter, back_state_vars = (var_counter, []) in
let back_state_vars = RegionGroupId.Map.of_list back_state_vars in
let ctx =
@@ -195,28 +179,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
in
(* Add the backward inputs *)
- let ctx, backward_inputs_no_state, backward_inputs_with_state =
- if !Config.return_back_funs then (ctx, [], [])
- else
- let ctx, inputs_no_with_state =
- List.fold_left_map
- (fun ctx (region_vars : region_var_group) ->
- let gid = region_vars.id in
- let back_sg = RegionGroupId.Map.find gid sg.back_sg in
- let ctx, no_state =
- SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx
- in
- let ctx, with_state =
- SymbolicToPure.fresh_vars back_sg.inputs ctx
- in
- (ctx, ((gid, no_state), (gid, with_state))))
- ctx regions_hierarchy
- in
- let inputs_no_state, inputs_with_state =
- List.split inputs_no_with_state
- in
- (ctx, inputs_no_state, inputs_with_state)
- in
+ let backward_inputs_no_state, backward_inputs_with_state = ([], []) in
let backward_inputs_no_state =
RegionGroupId.Map.of_list backward_inputs_no_state
in
@@ -225,40 +188,10 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
in
let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in
- (* Translate the forward function *)
- let pure_forward =
- match symbolic_trans with
- | None -> SymbolicToPure.translate_fun_decl ctx None
- | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast)
- in
-
- (* Translate the backward functions, if we split the forward and backward functions *)
- let translate_backward (rg : region_var_group) : Pure.fun_decl =
- (* For the backward inputs/outputs initialization: we use the fact that
- * there are no nested borrows for now, and so that the region groups
- * can't have parents *)
- assert (rg.parents = []);
- let back_id = rg.id in
-
- match symbolic_trans with
- | None ->
- (* Initialize the context *)
- let ctx = { ctx with bid = Some back_id } in
- (* Translate *)
- SymbolicToPure.translate_fun_decl ctx None
- | Some (_, symbolic) ->
- (* Initialize the context *)
- let ctx = { ctx with bid = Some back_id } in
- (* Translate *)
- SymbolicToPure.translate_fun_decl ctx (Some symbolic)
- in
- let pure_backwards =
- if !Config.return_back_funs then []
- else List.map translate_backward regions_hierarchy
- in
-
- (* Return *)
- (pure_forward, pure_backwards)
+ (* Translate the function *)
+ match symbolic_trans with
+ | None -> SymbolicToPure.translate_fun_decl ctx None
+ | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast)
(* TODO: factor out the return type *)
let translate_crate_to_pure (crate : crate) :
@@ -513,9 +446,8 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
let global_decls = ctx.trans_ctx.global_ctx.global_decls in
let global = GlobalDeclId.Map.find id global_decls in
let trans = FunDeclId.Map.find global.body ctx.trans_funs in
- assert (trans.fwd.loops = []);
- assert (trans.backs = []);
- let body = trans.fwd.f in
+ assert (trans.loops = []);
+ let body = trans.f in
let is_opaque = Option.is_none body.Pure.body in
(* Check if we extract the global *)
@@ -643,7 +575,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
let funs_map = builtin_funs_map () in
List.map
(fun (trans : pure_fun_translation) ->
- match_name_find_opt ctx.trans_ctx trans.fwd.f.llbc_name funs_map <> None)
+ match_name_find_opt ctx.trans_ctx trans.f.llbc_name funs_map <> None)
pure_ls
in
@@ -660,7 +592,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Extract the decrease clauses template bodies *)
if config.extract_template_decreases_clauses then
List.iter
- (fun { fwd; _ } ->
+ (fun f ->
(* We only generate decreases clauses for the forward functions, because
the termination argument should only depend on the forward inputs.
The backward functions thus use the same decreases clauses as the
@@ -687,27 +619,14 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
raise
(Failure "HOL4 doesn't have decreases/termination clauses")
in
- extract_decrease fwd.f;
- List.iter extract_decrease fwd.loops)
+ extract_decrease f.f;
+ List.iter extract_decrease f.loops)
pure_ls;
- (* Concatenate the function definitions, filtering the useless forward
- * functions. *)
+ (* Flatten the translated functions (concatenate the functions with
+ the declarations introduced for the loops) *)
let decls =
- List.concat
- (List.map
- (fun { keep_fwd; fwd; backs } ->
- let fwd =
- if keep_fwd then List.append fwd.loops [ fwd.f ] else []
- in
- let backs : Pure.fun_decl list =
- List.concat
- (List.map
- (fun back -> List.append back.loops [ back.f ])
- backs)
- in
- List.append fwd backs)
- pure_ls)
+ List.concat (List.map (fun f -> List.append f.loops [ f.f ]) pure_ls)
in
(* Extract the function definitions *)
@@ -724,9 +643,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Insert unit tests if necessary *)
if config.test_trans_unit_functions then
List.iter
- (fun trans ->
- if trans.keep_fwd then
- Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f)
+ (fun trans -> Extract.extract_unit_test_if_unit_fun ctx fmt trans.f)
pure_ls
(** Export a trait declaration. *)
@@ -812,7 +729,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
extract their type directly in the records we generate for
the trait declarations themselves, there is no point in having
separate type definitions) *)
- match pure_fun.fwd.f.Pure.kind with
+ match pure_fun.f.Pure.kind with
| TraitMethodDecl _ -> ()
| _ ->
(* Translate *)
@@ -1001,18 +918,18 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
* whether we should generate a decrease clause or not. *)
let rec_functions =
List.map
- (fun { fwd; _ } ->
- let fwd_f =
- if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then
- [ (fwd.f.def_id, None) ]
+ (fun trans ->
+ let f =
+ if trans.f.Pure.signature.fwd_info.effect_info.is_rec then
+ [ (trans.f.def_id, None) ]
else []
in
- let loop_fwds =
+ let loops =
List.map
(fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ])
- fwd.loops
+ trans.loops
in
- fwd_f :: loop_fwds)
+ f :: loops)
trans_funs
in
let rec_functions : PureUtils.fun_loop_id list =
@@ -1028,7 +945,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
let trans_funs : pure_fun_translation FunDeclId.Map.t =
FunDeclId.Map.of_list
(List.map
- (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans))
+ (fun (trans : pure_fun_translation) -> (trans.f.def_id, trans))
trans_funs)
in
@@ -1052,7 +969,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
names_maps;
indent_incr = 2;
use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses;
- fun_name_info = PureUtils.RegularFunIdMap.empty;
trait_decl_id = None (* None by default *);
is_provided_method = false (* false by default *);
trans_trait_decls;
@@ -1082,7 +998,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
(fun ctx (trans : pure_fun_translation) ->
(* If requested by the user, register termination measures and decreases
proofs for all the recursive functions *)
- let fwd_def = trans.fwd.f in
let gen_decr_clause (def : Pure.fun_decl) =
!Config.extract_decreases_clauses
&& PureUtils.FunLoopIdSet.mem
@@ -1091,7 +1006,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
in
(* Register the names, only if the function is not a global body -
* those are handled later *)
- let is_global = fwd_def.Pure.is_global_decl_body in
+ let is_global = trans.f.Pure.is_global_decl_body in
if is_global then ctx
else Extract.extract_fun_decl_register_names ctx gen_decr_clause trans)
ctx
@@ -1171,13 +1086,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
let exe_dir = Filename.dirname Sys.argv.(0) in
let primitives_src_dest =
match !Config.backend with
- | FStar ->
- let src =
- if !Config.return_back_funs then
- "/backends/fstar/merge/Primitives.fst"
- else "/backends/fstar/split/Primitives.fst"
- in
- Some (src, "Primitives.fst")
+ | FStar -> Some ("/backends/fstar/merge/Primitives.fst", "Primitives.fst")
| Coq -> Some ("/backends/coq/Primitives.v", "Primitives.v")
| Lean -> None
| HOL4 -> None
diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml
index 88438872..05877b5a 100644
--- a/compiler/TranslateCore.ml
+++ b/compiler/TranslateCore.ml
@@ -8,19 +8,8 @@ let log = Logging.translate_log
type trans_ctx = decls_ctx [@@deriving show]
type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list }
-type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list
-
-type pure_fun_translation = {
- keep_fwd : bool;
- (** Should we extract the forward function?
-
- If the forward function returns `()` and there is exactly one
- backward function, we may merge the forward into the backward
- function and thus don't extract the forward function)?
- *)
- fwd : fun_and_loops;
- backs : fun_and_loops list;
-}
+type pure_fun_translation_no_loops = Pure.fun_decl
+type pure_fun_translation = fun_and_loops
let trans_ctx_to_fmt_env (ctx : trans_ctx) : Print.fmt_env =
Print.Contexts.decls_ctx_to_fmt_env ctx