diff options
author | Son Ho | 2024-03-08 16:06:35 +0100 |
---|---|---|
committer | Son Ho | 2024-03-08 16:06:35 +0100 |
commit | fe2a2cb34148e46e32cdcfbf100e38d9986082cd (patch) | |
tree | a378134d495985718b842a786bad573695974b95 | |
parent | bc154dda94c44b3ae67a3b04d3866cc473aead32 (diff) |
Make progress on propagating the changes
-rw-r--r-- | compiler/Config.ml | 44 | ||||
-rw-r--r-- | compiler/Contexts.ml | 2 | ||||
-rw-r--r-- | compiler/Extract.ml | 444 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 161 | ||||
-rw-r--r-- | compiler/ExtractBuiltin.ml | 11 | ||||
-rw-r--r-- | compiler/ExtractTypes.ml | 2 | ||||
-rw-r--r-- | compiler/Main.ml | 6 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 18 | ||||
-rw-r--r-- | compiler/Pure.ml | 4 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 331 | ||||
-rw-r--r-- | compiler/ReorderDecls.ml | 12 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 13 | ||||
-rw-r--r-- | compiler/Translate.ml | 145 | ||||
-rw-r--r-- | compiler/TranslateCore.ml | 15 |
14 files changed, 259 insertions, 949 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index 6fd866e8..af0e62d1 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -263,50 +263,6 @@ let decompose_nested_let_patterns = ref false *) let unfold_monadic_let_bindings = ref false -(** Controls whether we try to filter the calls to monadic functions - (which can fail) when their outputs are not used. - - The useless calls are calls to backward functions which have no outputs. - This case happens if the original Rust function only takes *shared* borrows - as inputs, and is thus pretty common. - - We are allowed to do this only because in this specific case, - the backward function fails *exactly* when the forward function fails - (they actually do exactly the same thing, the only difference being - that the forward function can potentially return a value), and upon - reaching the place where we should introduce a call to the backward - function, we know we have introduced a call to the forward function. - - Also note that in general, backward functions "do more things" than - forward functions, and have more opportunities to fail (even though - in the generated code, calls to the backward functions should fail - exactly when the corresponding, previous call to the forward functions - failed). - - This optimization is done in {!SymbolicToPure}. We might want to move it to - the micro-passes subsequent to the translation from symbolic to pure, but it - is really super easy to do it when going from symbolic to pure. Note that - we later filter the useless *forward* calls in the micro-passes, where it is - more natural to do. - - See the comments for {!PureMicroPasses.expression_contains_child_call_in_all_paths} - for additional explanations. - *) -let filter_useless_monadic_calls = ref true - -(** If {!filter_useless_monadic_calls} is activated, some functions - become useless: if this option is true, we don't extract them. - - The calls to functions which always get filtered are: - - the forward functions with unit return value - - the backward functions which don't output anything (backward - functions coming from rust functions with no mutable borrows - as input values - note that if a function doesn't take mutable - borrows as inputs, it can't return mutable borrows; we actually - dynamically check for that). - *) -let filter_useless_functions = ref true - (** Simplify the forward/backward functions, in case we merge them (i.e., the forward functions return the backward functions). diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index b1dd9553..54411fd5 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -109,6 +109,8 @@ let reset_global_counters () = region_id_counter := RegionId.generator_zero; abstraction_id_counter := AbstractionId.generator_zero; loop_id_counter := LoopId.generator_zero; + (* We want the loop id to start at 1 *) + let _ = fresh_loop_id () in fun_call_id_counter := FunCallId.generator_zero; dummy_var_id_counter := DummyVarId.generator_zero diff --git a/compiler/Extract.ml b/compiler/Extract.ml index dbca4f8f..794a1bfa 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -9,8 +9,7 @@ open TranslateCore open Config include ExtractTypes -(** Compute the names for all the pure functions generated from a rust function - (forward function and backward functions). +(** Compute the names for all the pure functions generated from a rust function. *) let extract_fun_decl_register_names (ctx : extraction_ctx) (has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) : @@ -19,63 +18,36 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) method implementations): we do not need to refer to them directly. We will only use their type for the fields of the records we generate for the trait declarations *) - match def.fwd.f.kind with + match def.f.kind with | TraitMethodDecl _ -> ctx | _ -> ( (* Check if the function is builtin *) let builtin = let open ExtractBuiltin in let funs_map = builtin_funs_map () in - match_name_find_opt ctx.trans_ctx def.fwd.f.llbc_name funs_map + match_name_find_opt ctx.trans_ctx def.f.llbc_name funs_map in (* Use the builtin names if necessary *) match builtin with - | Some (filter_info, info) -> - (* Register the filtering information, if there is *) + | Some (filter_info, fun_info) -> + (* Builtin function: register the filtering information, if there is *) let ctx = match filter_info with | Some keep -> { ctx with funs_filter_type_args_map = - FunDeclId.Map.add def.fwd.f.def_id keep + FunDeclId.Map.add def.f.def_id keep ctx.funs_filter_type_args_map; } | _ -> ctx in - let funs = - if !Config.return_back_funs then [ def.fwd.f ] - else - let backs = List.map (fun f -> f.f) def.backs in - if def.keep_fwd then def.fwd.f :: backs else backs - in - List.fold_left - (fun ctx (f : fun_decl) -> - let open ExtractBuiltin in - let fun_id = - (Pure.FunId (FRegular f.def_id), f.loop_id, f.back_id) - in - let fun_info = - List.find_opt - (fun (x : builtin_fun_info) -> x.rg = f.back_id) - info - in - match fun_info with - | Some fun_info -> - ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx - | None -> - raise - (Failure - ("Not found: " - ^ name_to_string ctx f.llbc_name - ^ ", " - ^ Print.option_to_string Pure.show_loop_id f.loop_id - ^ Print.option_to_string Pure.show_region_group_id - f.back_id))) - ctx funs + let f = def.f in + let open ExtractBuiltin in + let fun_id = (Pure.FunId (FRegular f.def_id), f.loop_id) in + ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx | None -> - let fwd = def.fwd in - let backs = def.backs in + (* Not builtin *) (* Register the decrease clauses, if necessary *) let register_decreases ctx def = if has_decreases_clause def then @@ -88,21 +60,15 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) | Lean -> ctx_add_decreases_proof def ctx else ctx in - let ctx = - List.fold_left register_decreases ctx (fwd.f :: fwd.loops) - in - let register_fun ctx f = ctx_add_fun_decl def f ctx in + (* We have to register the function itself, and the loops it + may contain (which are extracted as functions) *) + let funs = def.f :: def.loops in + (* Register the decrease clauses *) + let ctx = List.fold_left register_decreases ctx funs in + (* Register the name of the function and the loops *) + let register_fun ctx f = ctx_add_fun_decl f ctx in let register_funs ctx fl = List.fold_left register_fun ctx fl in - (* Register the names of the forward functions *) - let ctx = - if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx - in - (* Register the names of the backward functions *) - List.fold_left - (fun ctx { f = back; loops = loop_backs } -> - let ctx = register_fun ctx back in - register_funs ctx loop_backs) - ctx backs) + register_funs ctx funs) (** Simply add the global name to the context. *) let extract_global_decl_register_names (ctx : extraction_ctx) @@ -230,7 +196,7 @@ let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list) let decl = FunDeclId.Map.find id ctx.trans_funs in let err = "Ill-formed builtin information for function " - ^ name_to_string ctx decl.fwd.f.llbc_name + ^ name_to_string ctx decl.f.llbc_name ^ ": " ^ string_of_int (List.length filter) ^ " filtering arguments provided for " @@ -460,8 +426,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) ]} *) (match fun_id with - | FromLlbc - (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id, rg_id) -> + | FromLlbc (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id) -> (* We have to check whether the trait method is required or provided *) let trait_decl_id = trait_ref.trait_decl_ref.trait_decl_id in let trait_decl = @@ -477,7 +442,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref; let fun_name = ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id - method_name rg_id ctx + method_name ctx in let add_brackets (s : string) = if !backend = Coq then "(" ^ s ^ ")" else s @@ -486,9 +451,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) else (* Provided method: we see it as a regular function call, and use the function name *) - let fun_id = - FromLlbc (FunId (FRegular method_id.id), lp_id, rg_id) - in + let fun_id = FromLlbc (FunId (FRegular method_id.id), lp_id) in let fun_name = ctx_get_function fun_id ctx in F.pp_print_string fmt fun_name; @@ -513,7 +476,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) *) let types = match fun_id with - | FromLlbc (FunId (FRegular id), _, _) -> + | FromLlbc (FunId (FRegular id), _) -> fun_builtin_filter_types id generics.types ctx | _ -> Result.Ok generics.types in @@ -1392,11 +1355,6 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx) let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = - let { keep_fwd; num_backs } = - PureUtils.RegularFunIdMap.find - (Pure.FunId (FRegular def.def_id), def.loop_id, def.back_id) - ctx.fun_name_info - in let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]:" in let comment = let loop_comment = @@ -1404,23 +1362,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) | None -> "" | Some id -> " loop " ^ LoopId.to_string id ^ ":" in - let fwd_back_comment = - match def.back_id with - | None -> if !Config.return_back_funs then [] else [ "forward function" ] - | Some id -> - (* Check if there is only one backward function, and no forward function *) - if (not keep_fwd) && num_backs = 1 then - [ - "merged forward/backward function"; - "(there is a single backward function, and the forward function \ - returns ())"; - ] - else [ "backward function " ^ T.RegionGroupId.to_string id ] - in - match fwd_back_comment with - | [] -> [ comment_pre ^ loop_comment ] - | [ s ] -> [ comment_pre ^ loop_comment ^ " " ^ s ] - | s :: sl -> (comment_pre ^ loop_comment ^ " " ^ s) :: sl + [ comment_pre ^ loop_comment ] in extract_comment_with_span fmt comment def.meta.span @@ -1435,9 +1377,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit = assert (not def.is_global_decl_body); (* Retrieve the function name *) - let def_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let def_name = ctx_get_local_function def.def_id def.loop_id ctx in (* Add a break before *) if !backend <> HOL4 || not (decl_is_first_from_group kind) then F.pp_print_break fmt 0 0; @@ -1681,9 +1621,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = (* Retrieve the definition name *) - let def_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let def_name = ctx_get_local_function def.def_id def.loop_id ctx in assert (def.signature.generics.const_generics = []); (* Add the type/const gen parameters - note that we need those bindings only for the generation of the type (they are not top-level) *) @@ -1870,7 +1808,6 @@ let extract_global_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (global : A.global_decl) (body : fun_decl) (interface : bool) : unit = assert body.is_global_decl_body; - assert (Option.is_none body.back_id); assert (body.signature.inputs = []); assert (body.signature.generics = empty_generic_params); @@ -1883,9 +1820,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) let decl_name = ctx_get_global global.def_id ctx in let body_name = - ctx_get_function - (FromLlbc (Pure.FunId (FRegular global.body), None, None)) - ctx + ctx_get_function (FromLlbc (Pure.FunId (FRegular global.body), None)) ctx in let decl_ty, body_ty = @@ -2058,80 +1993,45 @@ let extract_trait_decl_method_names (ctx : extraction_ctx) let required_methods = trait_decl.required_methods in (* Compute the names *) let method_names = - (* We add one field per required forward/backward function *) - let get_funs_for_id (id : fun_decl_id) : fun_decl list = - let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in - if !Config.return_back_funs then [ trans.fwd.f ] - else List.map (fun f -> f.f) (trans.fwd :: trans.backs) - in match builtin_info with | None -> - (* We add one field per required forward/backward function *) - let compute_item_names (item_name : string) (id : fun_decl_id) : - string * (RegionGroupId.id option * string) list = - let compute_fun_name (f : fun_decl) : RegionGroupId.id option * string - = - (* We do something special to reuse the [ctx_compute_fun_decl] - function. TODO: make it cleaner. *) - let llbc_name : Types.name = - [ Types.PeIdent (item_name, Disambiguator.zero) ] - in - let f = { f with llbc_name } in - let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in - let name = ctx_compute_fun_name trans f ctx in - (* Add a prefix if necessary *) - let name = - if !Config.record_fields_short_names then name - else ctx_compute_trait_decl_name ctx trait_decl ^ "_" ^ name - in - (f.back_id, name) + (* Not a builtin function *) + let compute_item_name (item_name : string) (id : fun_decl_id) : + string * string = + let trans : pure_fun_translation = + FunDeclId.Map.find id ctx.trans_funs + in + let f = trans.f in + (* We do something special to reuse the [ctx_compute_fun_decl] + function. TODO: make it cleaner. *) + let llbc_name : Types.name = + [ Types.PeIdent (item_name, Disambiguator.zero) ] + in + let f = { f with llbc_name } in + let name = ctx_compute_fun_name f ctx in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx_compute_trait_decl_name ctx trait_decl ^ "_" ^ name in - let funs = get_funs_for_id id in - (item_name, List.map compute_fun_name funs) + (item_name, name) in - List.map (fun (name, id) -> compute_item_names name id) required_methods + List.map (fun (name, id) -> compute_item_name name id) required_methods | Some info -> + (* This is a builtin *) let funs_map = StringMap.of_list info.methods in List.map - (fun (item_name, fun_id) -> + (fun (item_name, _) -> let open ExtractBuiltin in let info = StringMap.find item_name funs_map in - let trans_funs = get_funs_for_id fun_id in - let find (trans_fun : fun_decl) = - let info = - List.find_opt - (fun (info : builtin_fun_info) -> info.rg = trans_fun.back_id) - info - in - match info with - | Some info -> (info.rg, info.extract_name) - | None -> - let err = - "Ill-formed builtin information for trait decl \"" - ^ name_to_string ctx trait_decl.llbc_name - ^ "\", method \"" ^ item_name - ^ "\": could not find name for region " - ^ Print.option_to_string Pure.show_region_group_id - trans_fun.back_id - in - log#serror err; - if !Config.fail_hard then raise (Failure err) - else (trans_fun.back_id, "%ERROR_BUILTIN_NAME_NOT_FOUND%") - in - let rg_with_name_list = List.map find trans_funs in - (item_name, rg_with_name_list)) + let fun_name = info.extract_name in + (item_name, fun_name)) required_methods in (* Register the names *) List.fold_left - (fun ctx (item_name, funs) -> - (* We add one field per required forward/backward function *) - List.fold_left - (fun ctx (rg, fun_name) -> - ctx_add - (TraitMethodId (trait_decl.def_id, item_name, rg)) - fun_name ctx) - ctx funs) + (fun ctx (item_name, fun_name) -> + ctx_add (TraitMethodId (trait_decl.def_id, item_name)) fun_name ctx) ctx method_names (** Similar to {!extract_type_decl_register_names} *) @@ -2263,46 +2163,41 @@ let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter) (* Lookup the definition *) let trans = A.FunDeclId.Map.find id ctx.trans_funs in (* Extract the items *) - let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in - let extract_method (f : fun_and_loops) = - let f = f.f in - let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in - let ty () = - (* Extract the generics *) - (* We need to add the generics specific to the method, by removing those - which actually apply to the trait decl *) - let generics = - let drop_trait_clauses = false in - generic_params_drop_prefix ~drop_trait_clauses decl.generics - f.signature.generics - in - (* Note that we do not filter the LLBC generic parameters. - This is ok because: - - we only use them to find meaningful names for the trait clauses - - we only generate trait clauses for the clauses we find in the - pure generics *) - let ctx, type_params, cg_params, trait_clauses = - ctx_add_generic_params f.llbc_name f.signature.llbc_generics generics - ctx - in - let backend_uses_forall = - match !backend with Coq | Lean -> true | FStar | HOL4 -> false - in - let generics_not_empty = generics <> empty_generic_params in - let use_forall = generics_not_empty && backend_uses_forall in - let use_arrows = generics_not_empty && not backend_uses_forall in - let use_forall_use_sep = false in - extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall - ~use_forall_use_sep ~use_arrows generics type_params cg_params - trait_clauses; - if use_forall then F.pp_print_string fmt ","; - (* Extract the inputs and output *) - F.pp_print_space fmt (); - extract_fun_inputs_output_parameters_types ctx fmt f + let f = trans.f in + let fun_name = ctx_get_trait_method decl.def_id item_name ctx in + let ty () = + (* Extract the generics *) + (* We need to add the generics specific to the method, by removing those + which actually apply to the trait decl *) + let generics = + let drop_trait_clauses = false in + generic_params_drop_prefix ~drop_trait_clauses decl.generics + f.signature.generics + in + (* Note that we do not filter the LLBC generic parameters. + This is ok because: + - we only use them to find meaningful names for the trait clauses + - we only generate trait clauses for the clauses we find in the + pure generics *) + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params f.llbc_name f.signature.llbc_generics generics ctx in - extract_trait_decl_item ctx fmt fun_name ty + let backend_uses_forall = + match !backend with Coq | Lean -> true | FStar | HOL4 -> false + in + let generics_not_empty = generics <> empty_generic_params in + let use_forall = generics_not_empty && backend_uses_forall in + let use_arrows = generics_not_empty && not backend_uses_forall in + let use_forall_use_sep = false in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall + ~use_forall_use_sep ~use_arrows generics type_params cg_params + trait_clauses; + if use_forall then F.pp_print_string fmt ","; + (* Extract the inputs and output *) + F.pp_print_space fmt (); + extract_fun_inputs_output_parameters_types ctx fmt f in - List.iter extract_method funs + extract_trait_decl_item ctx fmt fun_name ty (** Extract a trait declaration *) let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter) @@ -2494,21 +2389,10 @@ let extract_trait_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) decl.parent_clauses; (* The required methods *) List.iter - (fun (item_name, id) -> - (* Lookup the definition *) - let trans = A.FunDeclId.Map.find id ctx.trans_funs in + (fun (item_name, _) -> (* Extract the items *) - let funs = - if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs - in - let extract_for_method (f : fun_and_loops) = - let f = f.f in - let item_name = - ctx_get_trait_method decl.def_id item_name f.back_id ctx - in - extract_coq_arguments_instruction ctx fmt item_name num_params - in - List.iter extract_for_method funs) + let item_name = ctx_get_trait_method decl.def_id item_name ctx in + extract_coq_arguments_instruction ctx fmt item_name num_params) decl.required_methods; (* Add a space *) F.pp_print_space fmt ()) @@ -2531,75 +2415,71 @@ let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter) (* Lookup the definition *) let trans = A.FunDeclId.Map.find id ctx.trans_funs in (* Extract the items *) - let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in - let extract_method (f : fun_and_loops) = - let f = f.f in - let fun_name = ctx_get_trait_method trait_decl_id item_name f.back_id ctx in - let ty () = - (* Filter the generics if the method is a builtin *) - let i_tys, _, _ = impl_generics in - let impl_types, i_tys, f_tys = - match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with - | None -> (impl.generics.types, i_tys, f.signature.generics.types) - | Some filter -> - let filter_list filter ls = - let ls = List.combine filter ls in - List.filter_map (fun (b, ty) -> if b then Some ty else None) ls - in - let impl_types = impl.generics.types in - let impl_filter = - Collections.List.prefix (List.length impl_types) filter - in - let i_tys = i_tys in - let i_filter = Collections.List.prefix (List.length i_tys) filter in - ( filter_list impl_filter impl_types, - filter_list i_filter i_tys, - filter_list filter f.signature.generics.types ) - in - let f_generics = { f.signature.generics with types = f_tys } in - (* Extract the generics - we need to quantify over the generics which - are specific to the method, and call it will all the generics - (trait impl + method generics) *) - let f_generics = - let drop_trait_clauses = true in - generic_params_drop_prefix ~drop_trait_clauses - { impl.generics with types = impl_types } - f_generics - in - (* Register and print the quantified generics. - - Note that we do not filter the LLBC generic parameters. - This is ok because: - - we only use them to find meaningful names for the trait clauses - - we only generate trait clauses for the clauses we find in the - pure generics *) - let ctx, f_tys, f_cgs, f_tcs = - ctx_add_generic_params f.llbc_name f.signature.llbc_generics f_generics - ctx - in - let use_forall = f_generics <> empty_generic_params in - extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics - f_tys f_cgs f_tcs; - if use_forall then F.pp_print_string fmt ","; - (* Extract the function call *) - F.pp_print_space fmt (); - let fun_name = ctx_get_local_function f.def_id None f.back_id ctx in - F.pp_print_string fmt fun_name; - let all_generics = - let _, i_cgs, i_tcs = impl_generics in - List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ] - in - - (* Filter the generics if the function is builtin *) - List.iter - (fun p -> - F.pp_print_space fmt (); - F.pp_print_string fmt p) - all_generics + let f = trans.f in + let fun_name = ctx_get_trait_method trait_decl_id item_name ctx in + let ty () = + (* Filter the generics if the method is a builtin *) + let i_tys, _, _ = impl_generics in + let impl_types, i_tys, f_tys = + match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with + | None -> (impl.generics.types, i_tys, f.signature.generics.types) + | Some filter -> + let filter_list filter ls = + let ls = List.combine filter ls in + List.filter_map (fun (b, ty) -> if b then Some ty else None) ls + in + let impl_types = impl.generics.types in + let impl_filter = + Collections.List.prefix (List.length impl_types) filter + in + let i_tys = i_tys in + let i_filter = Collections.List.prefix (List.length i_tys) filter in + ( filter_list impl_filter impl_types, + filter_list i_filter i_tys, + filter_list filter f.signature.generics.types ) + in + let f_generics = { f.signature.generics with types = f_tys } in + (* Extract the generics - we need to quantify over the generics which + are specific to the method, and call it will all the generics + (trait impl + method generics) *) + let f_generics = + let drop_trait_clauses = true in + generic_params_drop_prefix ~drop_trait_clauses + { impl.generics with types = impl_types } + f_generics + in + (* Register and print the quantified generics. + + Note that we do not filter the LLBC generic parameters. + This is ok because: + - we only use them to find meaningful names for the trait clauses + - we only generate trait clauses for the clauses we find in the + pure generics *) + let ctx, f_tys, f_cgs, f_tcs = + ctx_add_generic_params f.llbc_name f.signature.llbc_generics f_generics + ctx + in + let use_forall = f_generics <> empty_generic_params in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics + f_tys f_cgs f_tcs; + if use_forall then F.pp_print_string fmt ","; + (* Extract the function call *) + F.pp_print_space fmt (); + let fun_name = ctx_get_local_function f.def_id None ctx in + F.pp_print_string fmt fun_name; + let all_generics = + let _, i_cgs, i_tcs = impl_generics in + List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ] in - extract_trait_impl_item ctx fmt fun_name ty + + (* Filter the generics if the function is builtin *) + List.iter + (fun p -> + F.pp_print_space fmt (); + F.pp_print_string fmt p) + all_generics in - List.iter extract_method funs + extract_trait_impl_item ctx fmt fun_name ty (** Extract a trait implementation *) let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) @@ -2766,8 +2646,6 @@ let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) *) let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = - (* We only insert unit tests for forward functions *) - assert (def.back_id = None); (* Check if this is a unit function *) let sg = def.signature in if @@ -2791,9 +2669,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "assert_norm"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); @@ -2807,9 +2683,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "Check"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); @@ -2820,9 +2694,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "#assert"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); @@ -2835,9 +2707,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) | HOL4 -> F.pp_print_string fmt "val _ = assert_return ("; F.pp_print_string fmt "“"; - let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 04686705..591e8aab 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -167,11 +167,7 @@ type id = | TraitImplId of TraitImplId.id | LocalTraitClauseId of TraitClauseId.id | TraitDeclConstructorId of TraitDeclId.id - | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option - (** Something peculiar with trait methods: because we have to take into - account forward/backward functions, we may need to generate fields - items per method. - *) + | TraitMethodId of TraitDeclId.id * string | TraitItemId of TraitDeclId.id * string (** A trait associated item which is not a method *) | TraitParentClauseId of TraitDeclId.id * TraitClauseId.id @@ -353,8 +349,6 @@ let basename_to_unique (names_set : StringSet.t) in if StringSet.mem basename names_set then gen 1 else basename -type fun_name_info = { keep_fwd : bool; num_backs : int } - type names_maps = { names_map : names_map; (** The map for id to names, where we forbid name collisions @@ -384,7 +378,7 @@ let allow_collisions (id : id) : bool = | FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitItemId _ | TraitMethodId _ -> !Config.record_fields_short_names - | FunId (Pure _ | FromLlbc (FunId (FAssumed _), _, _)) -> + | FunId (Pure _ | FromLlbc (FunId (FAssumed _), _)) -> (* We map several assumed functions to the same id *) true | _ -> false @@ -549,15 +543,6 @@ type extraction_ctx = { -- makes the if then else dependent ]} *) - fun_name_info : fun_name_info PureUtils.RegularFunIdMap.t; - (** Information used to filter and name functions - we use it - to print comments in the generated code, to help link - the generated code to the original code (information such - as: "this function is the backward function of ...", or - "this function is the merged forward/backward function of ..." - in case a Rust function only has one backward translation - and we filter the forward function because it returns unit. - *) trait_decl_id : trait_decl_id option; (** If we are extracting a trait declaration, identifies it *) is_provided_method : bool; @@ -668,14 +653,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = ^ TraitClauseId.to_string clause_id | TraitItemId (id, name) -> "trait_item_id: " ^ trait_decl_id_to_string id ^ ", type name: " ^ name - | TraitMethodId (trait_decl_id, fun_name, rg_id) -> - let fwd_back_kind = - match rg_id with - | None -> "forward" - | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id - in - trait_decl_id_to_string trait_decl_id - ^ ", method name (" ^ fwd_back_kind ^ "): " ^ fun_name + | TraitMethodId (trait_decl_id, fun_name) -> + trait_decl_id_to_string trait_decl_id ^ ", method name: " ^ fun_name | TraitSelfClauseId -> "trait_self_clause" let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = @@ -694,8 +673,8 @@ let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string = ctx_get (FunId id) ctx let ctx_get_local_function (id : A.FunDeclId.id) (lp : LoopId.id option) - (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = - ctx_get_function (FromLlbc (FunId (FRegular id), lp, rg)) ctx + (ctx : extraction_ctx) : string = + ctx_get_function (FromLlbc (FunId (FRegular id), lp)) ctx let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = assert (id <> TTuple); @@ -733,8 +712,8 @@ let ctx_get_trait_type (id : trait_decl_id) (item_name : string) ctx_get_trait_item id item_name ctx let ctx_get_trait_method (id : trait_decl_id) (item_name : string) - (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string = - ctx_get (TraitMethodId (id, item_name, rg_id)) ctx + (ctx : extraction_ctx) : string = + ctx_get (TraitMethodId (id, item_name)) ctx let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id) (ctx : extraction_ctx) : string = @@ -1164,8 +1143,7 @@ let initialize_names_maps () : names_maps = in let assumed_functions = List.map - (fun (fid, rg, name) -> - (FromLlbc (Pure.FunId (FAssumed fid), None, rg), name)) + (fun (fid, name) -> (FromLlbc (Pure.FunId (FAssumed fid), None), name)) init.assumed_llbc_functions @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions in @@ -1408,61 +1386,12 @@ let default_fun_loop_suffix (num_loops : int) (loop_id : LoopId.id option) : If this function admits only one loop, we omit it. *) if num_loops = 1 then "_loop" else "_loop" ^ LoopId.to_string loop_id -(** A helper function: generates a function suffix from a region group - information. +(** A helper function: generates a function suffix. TODO: move all those helpers. *) -let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) - (num_region_groups : int) (rg : region_group_info option) - ((keep_fwd, num_backs) : bool * int) : string = - let lp_suff = default_fun_loop_suffix num_loops loop_id in - - (* There are several cases: - - [rg] is [Some]: this is a forward function: - - we add "_fwd" - - [rg] is [None]: this is a backward function: - - this function has one extracted backward function: - - if the forward function has been filtered, we add nothing: - the forward function is useless, so the unique backward function - takes its place, in a way (in effect, we "merge" the forward - and the backward functions). - - otherwise we add "_back" - - this function has several backward functions: we add "_back" and an - additional suffix to identify the precise backward function - Note that we always add a suffix (in case there are no region groups, - we could not add the "_fwd" suffix) to prevent name clashes between - definitions (in particular between type and function definitions). - *) - let rg_suff = - (* TODO: make all the backends match what is done for Lean *) - match rg with - | None -> - if - (* In order to avoid name conflicts: - * - if the forward is eliminated, we add the suffix "_fwd" (it won't be used) - * - otherwise, no suffix (because the backward functions will have a suffix) - *) - num_backs = 1 && not keep_fwd - then "_fwd" - else "" - | Some rg -> - assert (num_region_groups > 0 && num_backs > 0); - if num_backs = 1 then - (* Exactly one backward function *) - if not keep_fwd then "" else "_back" - else if - (* Several region groups/backward functions: - - if all the regions in the group have names, we use those names - - otherwise we use an index - *) - List.for_all Option.is_some rg.region_names - then - (* Concatenate the region names *) - "_back" ^ String.concat "" (List.map Option.get rg.region_names) - else (* Use the region index *) - "_back" ^ RegionGroupId.to_string rg.id - in - lp_suff ^ rg_suff +let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) : string = + (* We only generate a suffix for the functions we generate from the loops *) + default_fun_loop_suffix num_loops loop_id (** Compute the name of a regular (non-assumed) function. @@ -1472,24 +1401,13 @@ let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) indices to derive unique names for the loops for instance - if there is exactly one loop, we don't need to use indices) - loop id (if pertinent) - - number of region groups - - region group information in case of a backward function - ([None] if forward function) - - pair: - - do we generate the forward function (it may have been filtered)? - - the number of *extracted backward functions* (same comment as for - the number of loops) - The number of extracted backward functions if not necessarily - equal to the number of region groups, because we may have - filtered some of them. TODO: use the fun id for the assumed functions. *) let ctx_compute_fun_name (ctx : extraction_ctx) (fname : llbc_name) - (num_loops : int) (loop_id : LoopId.id option) (num_rgs : int) - (rg : region_group_info option) (filter_info : bool * int) : string = + (num_loops : int) (loop_id : LoopId.id option) : string = let fname = ctx_compute_fun_name_no_suffix ctx fname in (* Compute the suffix *) - let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in + let suffix = default_fun_suffix num_loops loop_id in (* Concatenate *) fname ^ suffix @@ -1963,61 +1881,26 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : | None -> (* Not the case: "standard" registration *) let name = ctx_compute_global_name ctx def.name in - let body = FunId (FromLlbc (FunId (FRegular def.body), None, None)) in + let body = FunId (FromLlbc (FunId (FRegular def.body), None)) in let ctx = ctx_add decl (name ^ "_c") ctx in let ctx = ctx_add body (name ^ "_body") ctx in ctx -let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl) - (ctx : extraction_ctx) : string = - (* Lookup the LLBC def to compute the region group information *) - let def_id = def.def_id in - let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in - let sg = llbc_def.signature in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular def_id) - ctx.trans_ctx.fun_ctx.regions_hierarchies - in - let num_rgs = List.length regions_hierarchy in - let { keep_fwd; fwd = _; backs } = trans_group in - let num_backs = List.length backs in - let rg_info = - match def.back_id with - | None -> None - | Some rg_id -> - let rg = T.RegionGroupId.nth regions_hierarchy rg_id in - let region_names = - List.map - (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) - rg.regions - in - Some { id = rg_id; region_names } - in +let ctx_compute_fun_name (def : fun_decl) (ctx : extraction_ctx) : string = (* Add the function name *) - ctx_compute_fun_name ctx def.llbc_name def.num_loops def.loop_id num_rgs - rg_info (keep_fwd, num_backs) + ctx_compute_fun_name ctx def.llbc_name def.num_loops def.loop_id (* TODO: move to Extract *) -let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl) - (ctx : extraction_ctx) : extraction_ctx = +let ctx_add_fun_decl (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = (* Sanity check: the function should not be a global body - those are handled * separately *) assert (not def.is_global_decl_body); (* Lookup the LLBC def to compute the region group information *) let def_id = def.def_id in - let { keep_fwd; fwd = _; backs } = trans_group in - let num_backs = List.length backs in (* Add the function name *) - let def_name = ctx_compute_fun_name trans_group def ctx in - let fun_id = (Pure.FunId (FRegular def_id), def.loop_id, def.back_id) in - let ctx = ctx_add (FunId (FromLlbc fun_id)) def_name ctx in - (* Add the name info *) - { - ctx with - fun_name_info = - PureUtils.RegularFunIdMap.add fun_id { keep_fwd; num_backs } - ctx.fun_name_info; - } + let def_name = ctx_compute_fun_name def ctx in + let fun_id = (Pure.FunId (FRegular def_id), def.loop_id) in + ctx_add (FunId (FromLlbc fun_id)) def_name ctx let ctx_compute_type_decl_name (ctx : extraction_ctx) (def : type_decl) : string = diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index 3ea5655a..88de31fe 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -221,12 +221,11 @@ type builtin_fun_info = { extract_name : string } [@@deriving show] parameters. For instance, in the case of the `Vec` functions, there is a type parameter for the allocator to use, which we want to filter. *) -let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list - = +let builtin_funs () : (pattern * bool list option * builtin_fun_info) list = (* Small utility *) let mk_fun (rust_name : string) (extract_name : string option) (filter : bool list option) : - pattern * bool list option * builtin_fun_info list = + pattern * bool list option * builtin_fun_info = let rust_name = try parse_pattern rust_name with Failure _ -> @@ -238,7 +237,7 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list | Some name -> split_on_separator name in let basename = flatten_name extract_name in - let f = [ { extract_name = basename } ] in + let f = { extract_name = basename } in (rust_name, filter, f) in [ @@ -377,7 +376,7 @@ type builtin_trait_decl_info = { - a Rust name - an extraction name - a list of clauses *) - methods : (string * builtin_fun_info list) list; + methods : (string * builtin_fun_info) list; } [@@deriving show] @@ -418,7 +417,7 @@ let builtin_trait_decls_info () = if !record_fields_short_names then item_name else extract_name ^ "_" ^ item_name in - let fwd = [ { extract_name = basename } ] in + let fwd = { extract_name = basename } in (item_name, fwd) in List.map mk_method methods diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index a3dbf3cc..05b71b9f 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -272,7 +272,7 @@ let start_fun_decl_group (ctx : extraction_ctx) (fmt : F.formatter) if is_single_opaque_fun_decl_group dg then () else let compute_fun_def_name (def : Pure.fun_decl) : string = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx ^ "_def" + ctx_get_local_function def.def_id def.loop_id ctx ^ "_def" in let names = List.map compute_fun_def_name dg in (* Add a break before *) diff --git a/compiler/Main.ml b/compiler/Main.ml index 664ec067..3f5e62ad 100644 --- a/compiler/Main.ml +++ b/compiler/Main.ml @@ -72,12 +72,6 @@ let () = Arg.Symbol (backend_names, set_backend), " Specify the target backend" ); ("-dest", Arg.Set_string dest_dir, " Specify the output directory"); - ( "-no-filter-useless-calls", - Arg.Clear filter_useless_monadic_calls, - " Do not filter the useless function calls" ); - ( "-no-filter-useless-funs", - Arg.Clear filter_useless_functions, - " Do not filter the useless forward/backward functions" ); ( "-test-units", Arg.Set test_unit_functions, " Test the unit functions with the concrete (i.e., not symbolic) \ diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 66475d02..21ca7f08 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -462,21 +462,13 @@ let inst_fun_sig_to_string (env : fmt_env) (sg : inst_fun_sig) : string = let all_types = List.append inputs [ output ] in String.concat " -> " all_types -let fun_suffix (lp_id : LoopId.id option) (rg_id : T.RegionGroupId.id option) : - string = +let fun_suffix (lp_id : LoopId.id option) : string = let lp_suff = match lp_id with | None -> "" | Some lp_id -> "^loop" ^ LoopId.to_string lp_id in - - let rg_suff = - match rg_id with - | None -> "" - | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id - in - - lp_suff ^ rg_suff + lp_suff let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = match fid with @@ -505,7 +497,7 @@ let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string = let regular_fun_id_to_string (env : fmt_env) (fun_id : fun_id) : string = match fun_id with - | FromLlbc (fid, lp_id, rg_id) -> + | FromLlbc (fid, lp_id) -> let f = match fid with | FunId (FRegular fid) -> fun_decl_id_to_string env fid @@ -513,7 +505,7 @@ let regular_fun_id_to_string (env : fmt_env) (fun_id : fun_id) : string = | TraitMethod (trait_ref, method_name, _) -> trait_ref_to_string env true trait_ref ^ "." ^ method_name in - f ^ fun_suffix lp_id rg_id + f ^ fun_suffix lp_id | Pure fid -> pure_assumed_fun_id_to_string fid let unop_to_string (unop : unop) : string = @@ -746,7 +738,7 @@ and emeta_to_string (env : fmt_env) (meta : emeta) : string = let fun_decl_to_string (env : fmt_env) (def : fun_decl) : string = let env = { env with generics = def.signature.generics } in - let name = def.name ^ fun_suffix def.loop_id def.back_id in + let name = def.name ^ fun_suffix def.loop_id in let signature = fun_sig_to_string env def.signature in match def.body with | None -> "val " ^ name ^ " :\n " ^ signature diff --git a/compiler/Pure.ml b/compiler/Pure.ml index a879ba37..dd7a4acf 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -560,8 +560,7 @@ type fun_id_or_trait_method_ref = [@@deriving show, ord] (** A function id for a non-assumed function *) -type regular_fun_id = - fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option +type regular_fun_id = fun_id_or_trait_method_ref * LoopId.id option [@@deriving show, ord] (** A function identifier *) @@ -1078,7 +1077,6 @@ type fun_decl = { *) loop_id : LoopId.id option; (** [Some] if this definition was generated for a loop *) - back_id : RegionGroupId.id option; llbc_name : llbc_name; (** The original LLBC name. *) name : string; (** We use the name only for printing purposes (for debugging): diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index ec64df21..04bc90d7 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -925,156 +925,9 @@ let inline_useless_var_reassignments (ctx : trans_ctx) ~(inline_named : bool) in { def with body = Some body } -(** For the cases where we split the forward/backward functions. - - Given a forward or backward function call, is there, for every execution - path, a child backward function called later with exactly the same input - list prefix. We use this to filter useless function calls: if there are - such child calls, we can remove this one (in case its outputs are not - used). - We do this check because we can't simply remove function calls whose - outputs are not used, as they might fail. However, if a function fails, - its children backward functions then fail on the same inputs (ignoring - the additional inputs those receive). - - For instance, if we have: - {[ - fn f<'a>(x : &'a mut T); - ]} - - We often have things like this in the synthesized code: - {[ - _ <-- f@fwd x; - ... - nx <-- f@back'a x y; - ... - ]} - - If [f@back'a x y] fails, then necessarily [f@fwd x] also fails. - In this situation, we can remove the call [f@fwd x]. - *) -let expression_contains_child_call_in_all_paths (ctx : trans_ctx) - (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option) - (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args) - (args0 : texpression list) (e : texpression) : bool = - let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args) - (args1 : texpression list) : bool = - (* Check the fun_ids, to see if call1's function is a child of call0's function *) - match fun_id1 with - | Fun (FromLlbc (id1, lp_id1, rg_id1)) -> - (* Both are "regular" calls: check if they come from the same rust function *) - if id0 = id1 && lp_id0 = lp_id1 then - (* Same rust functions: check the regions hierarchy *) - let call1_is_child = - match (rg_id0, rg_id1) with - | None, _ -> - (* The function used in call0 is the forward function: the one - * used in call1 is necessarily a child *) - true - | Some _, None -> - (* Opposite of previous case *) - false - | Some rg_id0, Some rg_id1 -> - if rg_id0 = rg_id1 then true - else - (* We need to use the regions hierarchy *) - let regions_hierarchy = - let id0 = - match id0 with - | FunId fun_id -> fun_id - | TraitMethod (_, _, fun_decl_id) -> FRegular fun_decl_id - in - LlbcAstUtils.FunIdMap.find id0 - ctx.fun_ctx.regions_hierarchies - in - (* Compute the set of ancestors of the function in call1 *) - let call1_ancestors = - LlbcAstUtils.list_ancestor_region_groups regions_hierarchy - rg_id1 - in - (* Check if the function used in call0 is inside *) - T.RegionGroupId.Set.mem rg_id0 call1_ancestors - in - (* If call1 is a child, then we need to check if the input arguments - * used in call0 are a prefix of the input arguments used in call1 - * (note call1 being a child, it will likely consume strictly more - * given back values). - * *) - if call1_is_child then - let call1_args = - Collections.List.prefix (List.length args0) args1 - in - let args = List.combine args0 call1_args in - (* Note that the input values are expressions, *which may contain - * meta-values* (which we need to ignore). *) - let input_eq (v0, v1) = - PureUtils.remove_meta v0 = PureUtils.remove_meta v1 - in - (* Compare the generics and the prefix of the input arguments *) - generics0 = generics1 && List.for_all input_eq args - else (* Not a child *) - false - else (* Not the same function *) - false - | _ -> false - in - - let visitor = - object (self) - inherit [_] reduce_expression - method zero _ = false - method plus b0 b1 _ = b0 () && b1 () - - method! visit_texpression env e = - match e.e with - | Var _ | CVar _ | Const _ -> fun _ -> false - | StructUpdate _ -> - (* There shouldn't be monadic calls in structure updates - also - note that by returning [false] we are conservative: we might - *prevent* possible optimisations (i.e., filtering some function - calls), which is sound. *) - fun _ -> false - | Let (_, _, re, e) -> ( - match opt_destruct_function_call re with - | None -> fun () -> self#visit_texpression env e () - | Some (func1, generics1, args1) -> - let call_is_child = check_call func1 generics1 args1 in - if call_is_child then fun () -> true - else fun () -> self#visit_texpression env e ()) - | Lambda (_, e) -> self#visit_texpression env e - | App _ -> ( - fun () -> - match opt_destruct_function_call e with - | Some (func1, tys1, args1) -> check_call func1 tys1 args1 - | None -> false) - | Qualif _ -> - (* Note that this case includes functions without arguments *) - fun () -> false - | Meta (_, e) -> self#visit_texpression env e - | Loop loop -> - (* We only visit the *function end* *) - self#visit_texpression env loop.fun_end - | Switch (_, body) -> self#visit_switch_body env body - - method! visit_switch_body env body = - match body with - | If (e1, e2) -> - fun () -> - self#visit_texpression env e1 () - && self#visit_texpression env e2 () - | Match branches -> - fun () -> - List.for_all - (fun br -> self#visit_texpression env br.branch ()) - branches - end - in - visitor#visit_texpression () e () - (** Filter the useless assignments (removes the useless variables, filters the function calls) *) -let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) - (def : fun_decl) : fun_decl = +let filter_useless (_ctx : trans_ctx) (def : fun_decl) : fun_decl = (* We first need a transformation on *left-values*, which filters the useless * variables and tells us whether the value contains any variable which has * not been replaced by [_] (in which case we need to keep the assignment, @@ -1166,30 +1019,8 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) if not monadic then (* Not a monadic let-binding: simple case *) (e.e, fun _ -> used) - else - (* Monadic let-binding: trickier. - * We can filter if the right-expression is a function call, - * under some conditions. *) - match (filter_monadic_calls, opt_destruct_function_call re) with - | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) -> - (* If we split the forward/backward functions. - - We need to check if there is a child call - see - the comments for: - [expression_contains_child_call_in_all_paths] *) - if not !Config.return_back_funs then - let has_child_call = - expression_contains_child_call_in_all_paths ctx fid - lp_id rg_id tys args e - in - if has_child_call then (* Filter *) - (e.e, fun _ -> used) - else (* No child call: don't filter *) - dont_filter () - else dont_filter () - | _ -> - (* Not an LLBC function call or not allowed to filter: we can't filter *) - dont_filter () + else (* Monadic let-binding: can't filter *) + dont_filter () else (* There are used variables: don't filter *) dont_filter () | Loop loop -> @@ -1442,22 +1273,6 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let body = { body with body = body_exp } in { def with body = Some body } -(** Return [None] if the function is a backward function with no outputs (so - that we eliminate the definition which is useless). - - Note that the calls to such functions are filtered when translating from - symbolic to pure. Here, we remove the definitions altogether, because they - are now useless - *) -let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = - if - !Config.filter_useless_functions - && Option.is_some def.back_id - && def.signature.output = mk_result_ty mk_unit_ty - || def.signature.output = mk_unit_ty - then None - else Some def - (** Retrieve the loop definitions from the function definition. {!SymbolicToPure} generates an AST in which the loop bodies are part of @@ -1530,14 +1345,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : info.num_inputs_with_fuel_no_state info.num_inputs_with_fuel_with_state in - let back_inputs = - if !Config.return_back_funs then [] - else - snd - (Collections.List.split_at fun_sig.inputs - info.num_inputs_with_fuel_with_state) - in - List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ] + List.concat [ fuel; fwd_inputs; fwd_state ] in let output = loop.output_ty in @@ -1618,7 +1426,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : kind = def.kind; num_loops; loop_id = Some loop.loop_id; - back_id = def.back_id; llbc_name = def.llbc_name; name = def.name; signature = loop_sig; @@ -1640,35 +1447,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : let loops = List.map snd (LoopId.Map.bindings !loops) in (def, loops) -(** Return [false] if the forward function is useless and should be filtered. - - - a forward function with no output (comes from a Rust function with - unit return type) - - the function has mutable borrows as inputs (which is materialized - by the fact we generated backward functions which were not filtered). - - In such situation, every call to the Rust function will be translated to: - - a call to the forward function which returns nothing - - calls to the backward functions - As a failing backward function implies the forward function also fails, - we can filter the calls to the forward function, which thus becomes - useless. - In such situation, we can remove the forward function definition - altogether. - *) -let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = - (* The question of filtering the forward functions arises only if we split - the forward/backward functions *) - if !Config.return_back_funs then true - else if - (* Note that at this point, the output types are no longer seen as tuples: - * they should be lists of length 1. *) - !Config.filter_useless_functions - && fwd.f.signature.output = mk_result_ty mk_unit_ty - && backs <> [] - then false - else true - (** Convert the unit variables to [()] if they are used as right-values or [_] if they are used as left values in patterns. *) let unit_vars_to_unit (def : fun_decl) : fun_decl = @@ -1724,19 +1502,17 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = * could have: [box_new f x]) * *) match fun_id with - | Fun (FromLlbc (FunId (FAssumed aid), _lp_id, rg_id)) -> ( - match (aid, rg_id) with - | BoxNew, _ -> - assert (rg_id = None); + | Fun (FromLlbc (FunId (FAssumed aid), _lp_id)) -> ( + match aid with + | BoxNew -> let arg, args = Collections.List.pop args in mk_apps arg args - | BoxFree, _ -> + | BoxFree -> assert (args = []); mk_unit_rvalue - | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared - | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut - | ArrayRepeat ), - _ ) -> + | SliceIndexShared | SliceIndexMut | ArrayIndexShared + | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut + | ArrayRepeat -> super#visit_texpression env e) | _ -> super#visit_texpression env e) | _ -> super#visit_texpression env e @@ -1989,7 +1765,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = (lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Filter the useless variables, assignments, function calls, etc. *) - let def = filter_useless !Config.filter_useless_monadic_calls ctx def in + let def = filter_useless ctx def in log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Simplify the lets immediately followed by a return. @@ -2130,16 +1906,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : *) let all_decls = List.concat - (List.concat - (List.concat - (List.map - (fun { fwd; backs; _ } -> - [ fwd.f :: fwd.loops ] - :: List.map - (fun { f = back; loops = loops_back } -> - [ back :: loops_back ]) - backs) - transl))) + (List.concat (List.map (fun { f; loops } -> [ f :: loops ]) transl)) in let subgroups = ReorderDecls.group_reorder_fun_decls all_decls in @@ -2207,7 +1974,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) -> + | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id'))) -> if (fun_id', loop_id') = fun_id then ( (* For each argument, check if it is exactly the original input parameter. Note that there shouldn't be partial @@ -2357,8 +2124,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _))) - -> ( + | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id))) -> ( match FunLoopIdMap.find_opt (fun_id, loop_id) !used_map with @@ -2400,13 +2166,8 @@ let filter_loop_inputs (transl : pure_fun_translation list) : in let transl = List.map - (fun trans -> - let filter_fun_and_loops f = - { f = filter_in_one f.f; loops = List.map filter_in_one f.loops } - in - let fwd = filter_fun_and_loops trans.fwd in - let backs = List.map filter_fun_and_loops trans.backs in - { trans with fwd; backs }) + (fun f -> + { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }) transl in @@ -2420,18 +2181,11 @@ let filter_loop_inputs (transl : pure_fun_translation list) : it thus returns the pair: (function def, loop defs). See {!decompose_loops} for more information. - Will return [None] if the function is a backward function with no outputs. - [ctx]: used only for printing. *) -let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : - fun_and_loops option = +let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_and_loops = (* Debug *) - log#ldebug - (lazy - ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string def.back_id - ^ ")")); + log#ldebug (lazy ("PureMicroPasses.apply_passes_to_def: " ^ def.name)); log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); @@ -2451,29 +2205,13 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : let def = remove_meta def in log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - (* Remove the backward functions with no outputs. + (* Extract the loop definitions by removing the {!Loop} node *) + let def, loops = decompose_loops ctx def in - Note that the *calls* to those functions should already have been removed, - when translating from symbolic to pure. Here, we remove the definitions - altogether, because they are now useless *) - let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in - let opt_def = filter_if_backward_with_no_outputs def in - - match opt_def with - | None -> - log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n")); - None - | Some def -> - log#ldebug - (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n")); - - (* Extract the loop definitions by removing the {!Loop} node *) - let def, loops = decompose_loops ctx def in - - (* Apply the remaining passes *) - let f = apply_end_passes_to_def ctx def in - let loops = List.map (apply_end_passes_to_def ctx) loops in - Some { f; loops } + (* Apply the remaining passes *) + let f = apply_end_passes_to_def ctx def in + let loops = List.map (apply_end_passes_to_def ctx) loops in + { f; loops } (** Apply the micro-passes to a list of forward/backward translations. @@ -2489,18 +2227,11 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : but convenient. *) let apply_passes_to_pure_fun_translations (ctx : trans_ctx) - (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list = - let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation = - (* Apply the passes to the individual functions *) - let fwd, backs = trans in - let fwd = Option.get (apply_passes_to_def ctx fwd) in - let backs = List.filter_map (apply_passes_to_def ctx) backs in - (* Compute whether we need to filter the forward function or not *) - let keep_fwd = keep_forward fwd backs in - { keep_fwd; fwd; backs } - in - - let transl = List.map apply_to_one transl in + (transl : fun_decl list) : pure_fun_translation list = + (* Apply the micro-passes *) + let transl = List.map (apply_passes_to_def ctx) transl in - (* Filter the useless inputs in the loop functions *) + (* Filter the useless inputs in the loop functions (loops are initially + parameterized by *all* the symbolic values in the context, because + they may access any of them). *) filter_loop_inputs transl diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml index 53c94ff4..f5443e03 100644 --- a/compiler/ReorderDecls.ml +++ b/compiler/ReorderDecls.ml @@ -5,11 +5,7 @@ open Pure (** The local logger *) let log = Logging.reorder_decls_log -type fun_id = { - def_id : FunDeclId.id; - lp_id : LoopId.id option; - rg_id : T.RegionGroupId.id option; -} +type fun_id = { def_id : FunDeclId.id; lp_id : LoopId.id option } [@@deriving show, ord] module FunIdOrderedType : OrderedType with type t = fun_id = struct @@ -43,11 +39,11 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t = | FunOrOp (Fun fid) -> ( match fid with | Pure _ -> () - | FromLlbc (fid, lp_id, rg_id) -> ( + | FromLlbc (fid, lp_id) -> ( match fid with | FunId (FAssumed _) -> () | TraitMethod (_, _, fid) | FunId (FRegular fid) -> - let id = { def_id = fid; lp_id; rg_id } in + let id = { def_id = fid; lp_id } in ids := FunIdSet.add id !ids)) end in @@ -71,7 +67,7 @@ let group_reorder_fun_decls (decls : fun_decl list) : (bool * fun_decl list) list = let module IntMap = MakeMap (OrderedInt) in let get_fun_id (decl : fun_decl) : fun_id = - { def_id = decl.def_id; lp_id = decl.loop_id; rg_id = decl.back_id } + { def_id = decl.def_id; lp_id = decl.loop_id } in (* Compute the list/set of identifiers *) let idl = List.map get_fun_id decls in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 859d6f17..2db5f66c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2035,7 +2035,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | S.Fun (fid, call_id) -> (* Regular function call *) let fid_t = translate_fun_id_or_trait_method_ref ctx fid in - let func = Fun (FromLlbc (fid_t, None, None)) in + let func = Fun (FromLlbc (fid_t, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = get_fun_effect_info ctx fid None None in @@ -2539,8 +2539,6 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in - let generics = loop_info.generics in - let fwd_inputs = Option.get loop_info.forward_inputs in (* Retrieve the additional backward inputs. Note that those are actually the backward inputs of the function we are synthesizing (and that we need to *transmit* to the loop backward function): they are not the @@ -2582,10 +2580,6 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output.ty else output.ty - in (* Create the expression for the function: - it is either a call to a top-level function, if we split the forward/backward functions @@ -3218,7 +3212,7 @@ and translate_forward_end (ectx : C.eval_ctx) let out_pat = mk_simpl_tuple_pattern out_pats in let loop_call = - let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in + let fun_id = Fun (FromLlbc (FunId fid, Some loop_id)) in let func = { id = FunOrOp fun_id; generics = loop_info.generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = @@ -3567,9 +3561,6 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let name = name_to_string ctx llbc_name in (* Translate the signature *) let signature = translate_fun_sig_from_decomposed ctx.sg in - let regions_hierarchy = - FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies - in (* Translate the body, if there is *) let body = match body with diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 55a94302..c12de045 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -1,5 +1,4 @@ open Interpreter -open Expressions open Types open Values open LlbcAst @@ -49,8 +48,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx) log#ldebug (lazy ("translate_function_to_pure: " ^ name_to_string trans_ctx fdef.name)); - let def_id = fdef.def_id in - (* Compute the symbolic ASTs, if the function is transparent *) let symbolic_trans = translate_function_to_symbolics trans_ctx fdef in @@ -124,20 +121,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_ctx.regions_hierarchies - in - - let var_counter, back_state_vars = - if !Config.return_back_funs then (var_counter, []) - else - List.fold_left_map - (fun var_counter (region_vars : region_var_group) -> - let gid = region_vars.id in - let var, var_counter = Pure.VarId.fresh var_counter in - (var_counter, (gid, var))) - var_counter regions_hierarchy - in + let var_counter, back_state_vars = (var_counter, []) in let back_state_vars = RegionGroupId.Map.of_list back_state_vars in let ctx = @@ -195,28 +179,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in (* Add the backward inputs *) - let ctx, backward_inputs_no_state, backward_inputs_with_state = - if !Config.return_back_funs then (ctx, [], []) - else - let ctx, inputs_no_with_state = - List.fold_left_map - (fun ctx (region_vars : region_var_group) -> - let gid = region_vars.id in - let back_sg = RegionGroupId.Map.find gid sg.back_sg in - let ctx, no_state = - SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx - in - let ctx, with_state = - SymbolicToPure.fresh_vars back_sg.inputs ctx - in - (ctx, ((gid, no_state), (gid, with_state)))) - ctx regions_hierarchy - in - let inputs_no_state, inputs_with_state = - List.split inputs_no_with_state - in - (ctx, inputs_no_state, inputs_with_state) - in + let backward_inputs_no_state, backward_inputs_with_state = ([], []) in let backward_inputs_no_state = RegionGroupId.Map.of_list backward_inputs_no_state in @@ -225,40 +188,10 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in - (* Translate the forward function *) - let pure_forward = - match symbolic_trans with - | None -> SymbolicToPure.translate_fun_decl ctx None - | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast) - in - - (* Translate the backward functions, if we split the forward and backward functions *) - let translate_backward (rg : region_var_group) : Pure.fun_decl = - (* For the backward inputs/outputs initialization: we use the fact that - * there are no nested borrows for now, and so that the region groups - * can't have parents *) - assert (rg.parents = []); - let back_id = rg.id in - - match symbolic_trans with - | None -> - (* Initialize the context *) - let ctx = { ctx with bid = Some back_id } in - (* Translate *) - SymbolicToPure.translate_fun_decl ctx None - | Some (_, symbolic) -> - (* Initialize the context *) - let ctx = { ctx with bid = Some back_id } in - (* Translate *) - SymbolicToPure.translate_fun_decl ctx (Some symbolic) - in - let pure_backwards = - if !Config.return_back_funs then [] - else List.map translate_backward regions_hierarchy - in - - (* Return *) - (pure_forward, pure_backwards) + (* Translate the function *) + match symbolic_trans with + | None -> SymbolicToPure.translate_fun_decl ctx None + | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast) (* TODO: factor out the return type *) let translate_crate_to_pure (crate : crate) : @@ -513,9 +446,8 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) let global_decls = ctx.trans_ctx.global_ctx.global_decls in let global = GlobalDeclId.Map.find id global_decls in let trans = FunDeclId.Map.find global.body ctx.trans_funs in - assert (trans.fwd.loops = []); - assert (trans.backs = []); - let body = trans.fwd.f in + assert (trans.loops = []); + let body = trans.f in let is_opaque = Option.is_none body.Pure.body in (* Check if we extract the global *) @@ -643,7 +575,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) let funs_map = builtin_funs_map () in List.map (fun (trans : pure_fun_translation) -> - match_name_find_opt ctx.trans_ctx trans.fwd.f.llbc_name funs_map <> None) + match_name_find_opt ctx.trans_ctx trans.f.llbc_name funs_map <> None) pure_ls in @@ -660,7 +592,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Extract the decrease clauses template bodies *) if config.extract_template_decreases_clauses then List.iter - (fun { fwd; _ } -> + (fun f -> (* We only generate decreases clauses for the forward functions, because the termination argument should only depend on the forward inputs. The backward functions thus use the same decreases clauses as the @@ -687,27 +619,14 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) raise (Failure "HOL4 doesn't have decreases/termination clauses") in - extract_decrease fwd.f; - List.iter extract_decrease fwd.loops) + extract_decrease f.f; + List.iter extract_decrease f.loops) pure_ls; - (* Concatenate the function definitions, filtering the useless forward - * functions. *) + (* Flatten the translated functions (concatenate the functions with + the declarations introduced for the loops) *) let decls = - List.concat - (List.map - (fun { keep_fwd; fwd; backs } -> - let fwd = - if keep_fwd then List.append fwd.loops [ fwd.f ] else [] - in - let backs : Pure.fun_decl list = - List.concat - (List.map - (fun back -> List.append back.loops [ back.f ]) - backs) - in - List.append fwd backs) - pure_ls) + List.concat (List.map (fun f -> List.append f.loops [ f.f ]) pure_ls) in (* Extract the function definitions *) @@ -724,9 +643,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Insert unit tests if necessary *) if config.test_trans_unit_functions then List.iter - (fun trans -> - if trans.keep_fwd then - Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f) + (fun trans -> Extract.extract_unit_test_if_unit_fun ctx fmt trans.f) pure_ls (** Export a trait declaration. *) @@ -812,7 +729,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) extract their type directly in the records we generate for the trait declarations themselves, there is no point in having separate type definitions) *) - match pure_fun.fwd.f.Pure.kind with + match pure_fun.f.Pure.kind with | TraitMethodDecl _ -> () | _ -> (* Translate *) @@ -1001,18 +918,18 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : * whether we should generate a decrease clause or not. *) let rec_functions = List.map - (fun { fwd; _ } -> - let fwd_f = - if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then - [ (fwd.f.def_id, None) ] + (fun trans -> + let f = + if trans.f.Pure.signature.fwd_info.effect_info.is_rec then + [ (trans.f.def_id, None) ] else [] in - let loop_fwds = + let loops = List.map (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ]) - fwd.loops + trans.loops in - fwd_f :: loop_fwds) + f :: loops) trans_funs in let rec_functions : PureUtils.fun_loop_id list = @@ -1028,7 +945,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : let trans_funs : pure_fun_translation FunDeclId.Map.t = FunDeclId.Map.of_list (List.map - (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans)) + (fun (trans : pure_fun_translation) -> (trans.f.def_id, trans)) trans_funs) in @@ -1052,7 +969,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : names_maps; indent_incr = 2; use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses; - fun_name_info = PureUtils.RegularFunIdMap.empty; trait_decl_id = None (* None by default *); is_provided_method = false (* false by default *); trans_trait_decls; @@ -1082,7 +998,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : (fun ctx (trans : pure_fun_translation) -> (* If requested by the user, register termination measures and decreases proofs for all the recursive functions *) - let fwd_def = trans.fwd.f in let gen_decr_clause (def : Pure.fun_decl) = !Config.extract_decreases_clauses && PureUtils.FunLoopIdSet.mem @@ -1091,7 +1006,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : in (* Register the names, only if the function is not a global body - * those are handled later *) - let is_global = fwd_def.Pure.is_global_decl_body in + let is_global = trans.f.Pure.is_global_decl_body in if is_global then ctx else Extract.extract_fun_decl_register_names ctx gen_decr_clause trans) ctx @@ -1171,13 +1086,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : let exe_dir = Filename.dirname Sys.argv.(0) in let primitives_src_dest = match !Config.backend with - | FStar -> - let src = - if !Config.return_back_funs then - "/backends/fstar/merge/Primitives.fst" - else "/backends/fstar/split/Primitives.fst" - in - Some (src, "Primitives.fst") + | FStar -> Some ("/backends/fstar/merge/Primitives.fst", "Primitives.fst") | Coq -> Some ("/backends/coq/Primitives.v", "Primitives.v") | Lean -> None | HOL4 -> None diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml index 88438872..05877b5a 100644 --- a/compiler/TranslateCore.ml +++ b/compiler/TranslateCore.ml @@ -8,19 +8,8 @@ let log = Logging.translate_log type trans_ctx = decls_ctx [@@deriving show] type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list } -type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list - -type pure_fun_translation = { - keep_fwd : bool; - (** Should we extract the forward function? - - If the forward function returns `()` and there is exactly one - backward function, we may merge the forward into the backward - function and thus don't extract the forward function)? - *) - fwd : fun_and_loops; - backs : fun_and_loops list; -} +type pure_fun_translation_no_loops = Pure.fun_decl +type pure_fun_translation = fun_and_loops let trans_ctx_to_fmt_env (ctx : trans_ctx) : Print.fmt_env = Print.Contexts.decls_ctx_to_fmt_env ctx |