diff options
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/Config.ml | 107 | ||||
-rw-r--r-- | compiler/Contexts.ml | 2 | ||||
-rw-r--r-- | compiler/Extract.ml | 444 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 243 | ||||
-rw-r--r-- | compiler/ExtractBuiltin.ml | 138 | ||||
-rw-r--r-- | compiler/ExtractTypes.ml | 2 | ||||
-rw-r--r-- | compiler/Main.ml | 9 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 18 | ||||
-rw-r--r-- | compiler/Pure.ml | 4 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 331 | ||||
-rw-r--r-- | compiler/ReorderDecls.ml | 12 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 949 | ||||
-rw-r--r-- | compiler/Translate.ml | 145 | ||||
-rw-r--r-- | compiler/TranslateCore.ml | 15 |
14 files changed, 696 insertions, 1723 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index 3b0070c0..af0e62d1 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -92,69 +92,6 @@ let loop_fixed_point_max_num_iters = 2 (** {1 Translation} *) -(** If true, do not define separate forward/backward functions, but make the - forward functions return the backward function. - - Example: - {[ - (* Rust *) - pub fn list_nth<'a, T>(l: &'a mut List<T>, i: u32) -> &'a mut T { - match l { - List::Nil => { - panic!() - } - List::Cons(x, tl) => { - if i == 0 { - x - } else { - list_nth(tl, i - 1) - } - } - } - } - - (* Translation, if return_back_funs = false *) - def list_nth (T : Type) (l : List T) (i : U32) : Result T := - match l with - | List.Cons x tl => - if i = 0#u32 - then Result.ret x - else do - let i0 ← i - 1#u32 - list_nth T tl i0 - | List.Nil => Result.fail .panic - - def list_nth_back - (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) := - match l with - | List.Cons x tl => - if i = 0#u32 - then Result.ret (List.Cons ret tl) - else - do - let i0 ← i - 1#u32 - let tl0 ← list_nth_back T tl i0 ret - Result.ret (List.Cons x tl0) - | List.Nil => Result.fail .panic - - (* Translation, if return_back_funs = true *) - def list_nth (T: Type) (ls : List T) (i : U32) : - Result (T × (T → Result (List T))) := - match ls with - | List.Cons x tl => - if i = 0#u32 - then Result.ret (x, (λ ret => return (ret :: ls))) - else do - let i0 ← i - 1#u32 - let (x, back) ← list_nth ls i0 - Return.ret (x, - (λ ret => do - let ls ← back ret - return (x :: ls))) - ]} - *) -let return_back_funs = ref true - (** Forbids using field projectors for structures. If we don't use field projectors, whenever we symbolically expand a structure @@ -326,50 +263,6 @@ let decompose_nested_let_patterns = ref false *) let unfold_monadic_let_bindings = ref false -(** Controls whether we try to filter the calls to monadic functions - (which can fail) when their outputs are not used. - - The useless calls are calls to backward functions which have no outputs. - This case happens if the original Rust function only takes *shared* borrows - as inputs, and is thus pretty common. - - We are allowed to do this only because in this specific case, - the backward function fails *exactly* when the forward function fails - (they actually do exactly the same thing, the only difference being - that the forward function can potentially return a value), and upon - reaching the place where we should introduce a call to the backward - function, we know we have introduced a call to the forward function. - - Also note that in general, backward functions "do more things" than - forward functions, and have more opportunities to fail (even though - in the generated code, calls to the backward functions should fail - exactly when the corresponding, previous call to the forward functions - failed). - - This optimization is done in {!SymbolicToPure}. We might want to move it to - the micro-passes subsequent to the translation from symbolic to pure, but it - is really super easy to do it when going from symbolic to pure. Note that - we later filter the useless *forward* calls in the micro-passes, where it is - more natural to do. - - See the comments for {!PureMicroPasses.expression_contains_child_call_in_all_paths} - for additional explanations. - *) -let filter_useless_monadic_calls = ref true - -(** If {!filter_useless_monadic_calls} is activated, some functions - become useless: if this option is true, we don't extract them. - - The calls to functions which always get filtered are: - - the forward functions with unit return value - - the backward functions which don't output anything (backward - functions coming from rust functions with no mutable borrows - as input values - note that if a function doesn't take mutable - borrows as inputs, it can't return mutable borrows; we actually - dynamically check for that). - *) -let filter_useless_functions = ref true - (** Simplify the forward/backward functions, in case we merge them (i.e., the forward functions return the backward functions). diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index b1dd9553..54411fd5 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -109,6 +109,8 @@ let reset_global_counters () = region_id_counter := RegionId.generator_zero; abstraction_id_counter := AbstractionId.generator_zero; loop_id_counter := LoopId.generator_zero; + (* We want the loop id to start at 1 *) + let _ = fresh_loop_id () in fun_call_id_counter := FunCallId.generator_zero; dummy_var_id_counter := DummyVarId.generator_zero diff --git a/compiler/Extract.ml b/compiler/Extract.ml index dbca4f8f..794a1bfa 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -9,8 +9,7 @@ open TranslateCore open Config include ExtractTypes -(** Compute the names for all the pure functions generated from a rust function - (forward function and backward functions). +(** Compute the names for all the pure functions generated from a rust function. *) let extract_fun_decl_register_names (ctx : extraction_ctx) (has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) : @@ -19,63 +18,36 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) method implementations): we do not need to refer to them directly. We will only use their type for the fields of the records we generate for the trait declarations *) - match def.fwd.f.kind with + match def.f.kind with | TraitMethodDecl _ -> ctx | _ -> ( (* Check if the function is builtin *) let builtin = let open ExtractBuiltin in let funs_map = builtin_funs_map () in - match_name_find_opt ctx.trans_ctx def.fwd.f.llbc_name funs_map + match_name_find_opt ctx.trans_ctx def.f.llbc_name funs_map in (* Use the builtin names if necessary *) match builtin with - | Some (filter_info, info) -> - (* Register the filtering information, if there is *) + | Some (filter_info, fun_info) -> + (* Builtin function: register the filtering information, if there is *) let ctx = match filter_info with | Some keep -> { ctx with funs_filter_type_args_map = - FunDeclId.Map.add def.fwd.f.def_id keep + FunDeclId.Map.add def.f.def_id keep ctx.funs_filter_type_args_map; } | _ -> ctx in - let funs = - if !Config.return_back_funs then [ def.fwd.f ] - else - let backs = List.map (fun f -> f.f) def.backs in - if def.keep_fwd then def.fwd.f :: backs else backs - in - List.fold_left - (fun ctx (f : fun_decl) -> - let open ExtractBuiltin in - let fun_id = - (Pure.FunId (FRegular f.def_id), f.loop_id, f.back_id) - in - let fun_info = - List.find_opt - (fun (x : builtin_fun_info) -> x.rg = f.back_id) - info - in - match fun_info with - | Some fun_info -> - ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx - | None -> - raise - (Failure - ("Not found: " - ^ name_to_string ctx f.llbc_name - ^ ", " - ^ Print.option_to_string Pure.show_loop_id f.loop_id - ^ Print.option_to_string Pure.show_region_group_id - f.back_id))) - ctx funs + let f = def.f in + let open ExtractBuiltin in + let fun_id = (Pure.FunId (FRegular f.def_id), f.loop_id) in + ctx_add (FunId (FromLlbc fun_id)) fun_info.extract_name ctx | None -> - let fwd = def.fwd in - let backs = def.backs in + (* Not builtin *) (* Register the decrease clauses, if necessary *) let register_decreases ctx def = if has_decreases_clause def then @@ -88,21 +60,15 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) | Lean -> ctx_add_decreases_proof def ctx else ctx in - let ctx = - List.fold_left register_decreases ctx (fwd.f :: fwd.loops) - in - let register_fun ctx f = ctx_add_fun_decl def f ctx in + (* We have to register the function itself, and the loops it + may contain (which are extracted as functions) *) + let funs = def.f :: def.loops in + (* Register the decrease clauses *) + let ctx = List.fold_left register_decreases ctx funs in + (* Register the name of the function and the loops *) + let register_fun ctx f = ctx_add_fun_decl f ctx in let register_funs ctx fl = List.fold_left register_fun ctx fl in - (* Register the names of the forward functions *) - let ctx = - if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx - in - (* Register the names of the backward functions *) - List.fold_left - (fun ctx { f = back; loops = loop_backs } -> - let ctx = register_fun ctx back in - register_funs ctx loop_backs) - ctx backs) + register_funs ctx funs) (** Simply add the global name to the context. *) let extract_global_decl_register_names (ctx : extraction_ctx) @@ -230,7 +196,7 @@ let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list) let decl = FunDeclId.Map.find id ctx.trans_funs in let err = "Ill-formed builtin information for function " - ^ name_to_string ctx decl.fwd.f.llbc_name + ^ name_to_string ctx decl.f.llbc_name ^ ": " ^ string_of_int (List.length filter) ^ " filtering arguments provided for " @@ -460,8 +426,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) ]} *) (match fun_id with - | FromLlbc - (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id, rg_id) -> + | FromLlbc (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id) -> (* We have to check whether the trait method is required or provided *) let trait_decl_id = trait_ref.trait_decl_ref.trait_decl_id in let trait_decl = @@ -477,7 +442,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref; let fun_name = ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id - method_name rg_id ctx + method_name ctx in let add_brackets (s : string) = if !backend = Coq then "(" ^ s ^ ")" else s @@ -486,9 +451,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) else (* Provided method: we see it as a regular function call, and use the function name *) - let fun_id = - FromLlbc (FunId (FRegular method_id.id), lp_id, rg_id) - in + let fun_id = FromLlbc (FunId (FRegular method_id.id), lp_id) in let fun_name = ctx_get_function fun_id ctx in F.pp_print_string fmt fun_name; @@ -513,7 +476,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) *) let types = match fun_id with - | FromLlbc (FunId (FRegular id), _, _) -> + | FromLlbc (FunId (FRegular id), _) -> fun_builtin_filter_types id generics.types ctx | _ -> Result.Ok generics.types in @@ -1392,11 +1355,6 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx) let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = - let { keep_fwd; num_backs } = - PureUtils.RegularFunIdMap.find - (Pure.FunId (FRegular def.def_id), def.loop_id, def.back_id) - ctx.fun_name_info - in let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]:" in let comment = let loop_comment = @@ -1404,23 +1362,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) | None -> "" | Some id -> " loop " ^ LoopId.to_string id ^ ":" in - let fwd_back_comment = - match def.back_id with - | None -> if !Config.return_back_funs then [] else [ "forward function" ] - | Some id -> - (* Check if there is only one backward function, and no forward function *) - if (not keep_fwd) && num_backs = 1 then - [ - "merged forward/backward function"; - "(there is a single backward function, and the forward function \ - returns ())"; - ] - else [ "backward function " ^ T.RegionGroupId.to_string id ] - in - match fwd_back_comment with - | [] -> [ comment_pre ^ loop_comment ] - | [ s ] -> [ comment_pre ^ loop_comment ^ " " ^ s ] - | s :: sl -> (comment_pre ^ loop_comment ^ " " ^ s) :: sl + [ comment_pre ^ loop_comment ] in extract_comment_with_span fmt comment def.meta.span @@ -1435,9 +1377,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit = assert (not def.is_global_decl_body); (* Retrieve the function name *) - let def_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let def_name = ctx_get_local_function def.def_id def.loop_id ctx in (* Add a break before *) if !backend <> HOL4 || not (decl_is_first_from_group kind) then F.pp_print_break fmt 0 0; @@ -1681,9 +1621,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = (* Retrieve the definition name *) - let def_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let def_name = ctx_get_local_function def.def_id def.loop_id ctx in assert (def.signature.generics.const_generics = []); (* Add the type/const gen parameters - note that we need those bindings only for the generation of the type (they are not top-level) *) @@ -1870,7 +1808,6 @@ let extract_global_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter) let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (global : A.global_decl) (body : fun_decl) (interface : bool) : unit = assert body.is_global_decl_body; - assert (Option.is_none body.back_id); assert (body.signature.inputs = []); assert (body.signature.generics = empty_generic_params); @@ -1883,9 +1820,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) let decl_name = ctx_get_global global.def_id ctx in let body_name = - ctx_get_function - (FromLlbc (Pure.FunId (FRegular global.body), None, None)) - ctx + ctx_get_function (FromLlbc (Pure.FunId (FRegular global.body), None)) ctx in let decl_ty, body_ty = @@ -2058,80 +1993,45 @@ let extract_trait_decl_method_names (ctx : extraction_ctx) let required_methods = trait_decl.required_methods in (* Compute the names *) let method_names = - (* We add one field per required forward/backward function *) - let get_funs_for_id (id : fun_decl_id) : fun_decl list = - let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in - if !Config.return_back_funs then [ trans.fwd.f ] - else List.map (fun f -> f.f) (trans.fwd :: trans.backs) - in match builtin_info with | None -> - (* We add one field per required forward/backward function *) - let compute_item_names (item_name : string) (id : fun_decl_id) : - string * (RegionGroupId.id option * string) list = - let compute_fun_name (f : fun_decl) : RegionGroupId.id option * string - = - (* We do something special to reuse the [ctx_compute_fun_decl] - function. TODO: make it cleaner. *) - let llbc_name : Types.name = - [ Types.PeIdent (item_name, Disambiguator.zero) ] - in - let f = { f with llbc_name } in - let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in - let name = ctx_compute_fun_name trans f ctx in - (* Add a prefix if necessary *) - let name = - if !Config.record_fields_short_names then name - else ctx_compute_trait_decl_name ctx trait_decl ^ "_" ^ name - in - (f.back_id, name) + (* Not a builtin function *) + let compute_item_name (item_name : string) (id : fun_decl_id) : + string * string = + let trans : pure_fun_translation = + FunDeclId.Map.find id ctx.trans_funs + in + let f = trans.f in + (* We do something special to reuse the [ctx_compute_fun_decl] + function. TODO: make it cleaner. *) + let llbc_name : Types.name = + [ Types.PeIdent (item_name, Disambiguator.zero) ] + in + let f = { f with llbc_name } in + let name = ctx_compute_fun_name f ctx in + (* Add a prefix if necessary *) + let name = + if !Config.record_fields_short_names then name + else ctx_compute_trait_decl_name ctx trait_decl ^ "_" ^ name in - let funs = get_funs_for_id id in - (item_name, List.map compute_fun_name funs) + (item_name, name) in - List.map (fun (name, id) -> compute_item_names name id) required_methods + List.map (fun (name, id) -> compute_item_name name id) required_methods | Some info -> + (* This is a builtin *) let funs_map = StringMap.of_list info.methods in List.map - (fun (item_name, fun_id) -> + (fun (item_name, _) -> let open ExtractBuiltin in let info = StringMap.find item_name funs_map in - let trans_funs = get_funs_for_id fun_id in - let find (trans_fun : fun_decl) = - let info = - List.find_opt - (fun (info : builtin_fun_info) -> info.rg = trans_fun.back_id) - info - in - match info with - | Some info -> (info.rg, info.extract_name) - | None -> - let err = - "Ill-formed builtin information for trait decl \"" - ^ name_to_string ctx trait_decl.llbc_name - ^ "\", method \"" ^ item_name - ^ "\": could not find name for region " - ^ Print.option_to_string Pure.show_region_group_id - trans_fun.back_id - in - log#serror err; - if !Config.fail_hard then raise (Failure err) - else (trans_fun.back_id, "%ERROR_BUILTIN_NAME_NOT_FOUND%") - in - let rg_with_name_list = List.map find trans_funs in - (item_name, rg_with_name_list)) + let fun_name = info.extract_name in + (item_name, fun_name)) required_methods in (* Register the names *) List.fold_left - (fun ctx (item_name, funs) -> - (* We add one field per required forward/backward function *) - List.fold_left - (fun ctx (rg, fun_name) -> - ctx_add - (TraitMethodId (trait_decl.def_id, item_name, rg)) - fun_name ctx) - ctx funs) + (fun ctx (item_name, fun_name) -> + ctx_add (TraitMethodId (trait_decl.def_id, item_name)) fun_name ctx) ctx method_names (** Similar to {!extract_type_decl_register_names} *) @@ -2263,46 +2163,41 @@ let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter) (* Lookup the definition *) let trans = A.FunDeclId.Map.find id ctx.trans_funs in (* Extract the items *) - let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in - let extract_method (f : fun_and_loops) = - let f = f.f in - let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in - let ty () = - (* Extract the generics *) - (* We need to add the generics specific to the method, by removing those - which actually apply to the trait decl *) - let generics = - let drop_trait_clauses = false in - generic_params_drop_prefix ~drop_trait_clauses decl.generics - f.signature.generics - in - (* Note that we do not filter the LLBC generic parameters. - This is ok because: - - we only use them to find meaningful names for the trait clauses - - we only generate trait clauses for the clauses we find in the - pure generics *) - let ctx, type_params, cg_params, trait_clauses = - ctx_add_generic_params f.llbc_name f.signature.llbc_generics generics - ctx - in - let backend_uses_forall = - match !backend with Coq | Lean -> true | FStar | HOL4 -> false - in - let generics_not_empty = generics <> empty_generic_params in - let use_forall = generics_not_empty && backend_uses_forall in - let use_arrows = generics_not_empty && not backend_uses_forall in - let use_forall_use_sep = false in - extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall - ~use_forall_use_sep ~use_arrows generics type_params cg_params - trait_clauses; - if use_forall then F.pp_print_string fmt ","; - (* Extract the inputs and output *) - F.pp_print_space fmt (); - extract_fun_inputs_output_parameters_types ctx fmt f + let f = trans.f in + let fun_name = ctx_get_trait_method decl.def_id item_name ctx in + let ty () = + (* Extract the generics *) + (* We need to add the generics specific to the method, by removing those + which actually apply to the trait decl *) + let generics = + let drop_trait_clauses = false in + generic_params_drop_prefix ~drop_trait_clauses decl.generics + f.signature.generics + in + (* Note that we do not filter the LLBC generic parameters. + This is ok because: + - we only use them to find meaningful names for the trait clauses + - we only generate trait clauses for the clauses we find in the + pure generics *) + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params f.llbc_name f.signature.llbc_generics generics ctx in - extract_trait_decl_item ctx fmt fun_name ty + let backend_uses_forall = + match !backend with Coq | Lean -> true | FStar | HOL4 -> false + in + let generics_not_empty = generics <> empty_generic_params in + let use_forall = generics_not_empty && backend_uses_forall in + let use_arrows = generics_not_empty && not backend_uses_forall in + let use_forall_use_sep = false in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall + ~use_forall_use_sep ~use_arrows generics type_params cg_params + trait_clauses; + if use_forall then F.pp_print_string fmt ","; + (* Extract the inputs and output *) + F.pp_print_space fmt (); + extract_fun_inputs_output_parameters_types ctx fmt f in - List.iter extract_method funs + extract_trait_decl_item ctx fmt fun_name ty (** Extract a trait declaration *) let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter) @@ -2494,21 +2389,10 @@ let extract_trait_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) decl.parent_clauses; (* The required methods *) List.iter - (fun (item_name, id) -> - (* Lookup the definition *) - let trans = A.FunDeclId.Map.find id ctx.trans_funs in + (fun (item_name, _) -> (* Extract the items *) - let funs = - if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs - in - let extract_for_method (f : fun_and_loops) = - let f = f.f in - let item_name = - ctx_get_trait_method decl.def_id item_name f.back_id ctx - in - extract_coq_arguments_instruction ctx fmt item_name num_params - in - List.iter extract_for_method funs) + let item_name = ctx_get_trait_method decl.def_id item_name ctx in + extract_coq_arguments_instruction ctx fmt item_name num_params) decl.required_methods; (* Add a space *) F.pp_print_space fmt ()) @@ -2531,75 +2415,71 @@ let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter) (* Lookup the definition *) let trans = A.FunDeclId.Map.find id ctx.trans_funs in (* Extract the items *) - let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in - let extract_method (f : fun_and_loops) = - let f = f.f in - let fun_name = ctx_get_trait_method trait_decl_id item_name f.back_id ctx in - let ty () = - (* Filter the generics if the method is a builtin *) - let i_tys, _, _ = impl_generics in - let impl_types, i_tys, f_tys = - match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with - | None -> (impl.generics.types, i_tys, f.signature.generics.types) - | Some filter -> - let filter_list filter ls = - let ls = List.combine filter ls in - List.filter_map (fun (b, ty) -> if b then Some ty else None) ls - in - let impl_types = impl.generics.types in - let impl_filter = - Collections.List.prefix (List.length impl_types) filter - in - let i_tys = i_tys in - let i_filter = Collections.List.prefix (List.length i_tys) filter in - ( filter_list impl_filter impl_types, - filter_list i_filter i_tys, - filter_list filter f.signature.generics.types ) - in - let f_generics = { f.signature.generics with types = f_tys } in - (* Extract the generics - we need to quantify over the generics which - are specific to the method, and call it will all the generics - (trait impl + method generics) *) - let f_generics = - let drop_trait_clauses = true in - generic_params_drop_prefix ~drop_trait_clauses - { impl.generics with types = impl_types } - f_generics - in - (* Register and print the quantified generics. - - Note that we do not filter the LLBC generic parameters. - This is ok because: - - we only use them to find meaningful names for the trait clauses - - we only generate trait clauses for the clauses we find in the - pure generics *) - let ctx, f_tys, f_cgs, f_tcs = - ctx_add_generic_params f.llbc_name f.signature.llbc_generics f_generics - ctx - in - let use_forall = f_generics <> empty_generic_params in - extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics - f_tys f_cgs f_tcs; - if use_forall then F.pp_print_string fmt ","; - (* Extract the function call *) - F.pp_print_space fmt (); - let fun_name = ctx_get_local_function f.def_id None f.back_id ctx in - F.pp_print_string fmt fun_name; - let all_generics = - let _, i_cgs, i_tcs = impl_generics in - List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ] - in - - (* Filter the generics if the function is builtin *) - List.iter - (fun p -> - F.pp_print_space fmt (); - F.pp_print_string fmt p) - all_generics + let f = trans.f in + let fun_name = ctx_get_trait_method trait_decl_id item_name ctx in + let ty () = + (* Filter the generics if the method is a builtin *) + let i_tys, _, _ = impl_generics in + let impl_types, i_tys, f_tys = + match FunDeclId.Map.find_opt f.def_id ctx.funs_filter_type_args_map with + | None -> (impl.generics.types, i_tys, f.signature.generics.types) + | Some filter -> + let filter_list filter ls = + let ls = List.combine filter ls in + List.filter_map (fun (b, ty) -> if b then Some ty else None) ls + in + let impl_types = impl.generics.types in + let impl_filter = + Collections.List.prefix (List.length impl_types) filter + in + let i_tys = i_tys in + let i_filter = Collections.List.prefix (List.length i_tys) filter in + ( filter_list impl_filter impl_types, + filter_list i_filter i_tys, + filter_list filter f.signature.generics.types ) + in + let f_generics = { f.signature.generics with types = f_tys } in + (* Extract the generics - we need to quantify over the generics which + are specific to the method, and call it will all the generics + (trait impl + method generics) *) + let f_generics = + let drop_trait_clauses = true in + generic_params_drop_prefix ~drop_trait_clauses + { impl.generics with types = impl_types } + f_generics + in + (* Register and print the quantified generics. + + Note that we do not filter the LLBC generic parameters. + This is ok because: + - we only use them to find meaningful names for the trait clauses + - we only generate trait clauses for the clauses we find in the + pure generics *) + let ctx, f_tys, f_cgs, f_tcs = + ctx_add_generic_params f.llbc_name f.signature.llbc_generics f_generics + ctx + in + let use_forall = f_generics <> empty_generic_params in + extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics + f_tys f_cgs f_tcs; + if use_forall then F.pp_print_string fmt ","; + (* Extract the function call *) + F.pp_print_space fmt (); + let fun_name = ctx_get_local_function f.def_id None ctx in + F.pp_print_string fmt fun_name; + let all_generics = + let _, i_cgs, i_tcs = impl_generics in + List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ] in - extract_trait_impl_item ctx fmt fun_name ty + + (* Filter the generics if the function is builtin *) + List.iter + (fun p -> + F.pp_print_space fmt (); + F.pp_print_string fmt p) + all_generics in - List.iter extract_method funs + extract_trait_impl_item ctx fmt fun_name ty (** Extract a trait implementation *) let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) @@ -2766,8 +2646,6 @@ let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) *) let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = - (* We only insert unit tests for forward functions *) - assert (def.back_id = None); (* Check if this is a unit function *) let sg = def.signature in if @@ -2791,9 +2669,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "assert_norm"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); @@ -2807,9 +2683,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "Check"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); @@ -2820,9 +2694,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "#assert"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); @@ -2835,9 +2707,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) | HOL4 -> F.pp_print_string fmt "val _ = assert_return ("; F.pp_print_string fmt "“"; - let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx - in + let fun_name = ctx_get_local_function def.def_id def.loop_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 5aa8323e..591e8aab 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -167,11 +167,7 @@ type id = | TraitImplId of TraitImplId.id | LocalTraitClauseId of TraitClauseId.id | TraitDeclConstructorId of TraitDeclId.id - | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option - (** Something peculiar with trait methods: because we have to take into - account forward/backward functions, we may need to generate fields - items per method. - *) + | TraitMethodId of TraitDeclId.id * string | TraitItemId of TraitDeclId.id * string (** A trait associated item which is not a method *) | TraitParentClauseId of TraitDeclId.id * TraitClauseId.id @@ -353,8 +349,6 @@ let basename_to_unique (names_set : StringSet.t) in if StringSet.mem basename names_set then gen 1 else basename -type fun_name_info = { keep_fwd : bool; num_backs : int } - type names_maps = { names_map : names_map; (** The map for id to names, where we forbid name collisions @@ -384,7 +378,7 @@ let allow_collisions (id : id) : bool = | FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitItemId _ | TraitMethodId _ -> !Config.record_fields_short_names - | FunId (Pure _ | FromLlbc (FunId (FAssumed _), _, _)) -> + | FunId (Pure _ | FromLlbc (FunId (FAssumed _), _)) -> (* We map several assumed functions to the same id *) true | _ -> false @@ -471,8 +465,7 @@ type names_map_init = { assumed_adts : (assumed_ty * string) list; assumed_structs : (assumed_ty * string) list; assumed_variants : (assumed_ty * VariantId.id * string) list; - assumed_llbc_functions : - (A.assumed_fun_id * RegionGroupId.id option * string) list; + assumed_llbc_functions : (A.assumed_fun_id * string) list; assumed_pure_functions : (pure_assumed_fun_id * string) list; } @@ -550,15 +543,6 @@ type extraction_ctx = { -- makes the if then else dependent ]} *) - fun_name_info : fun_name_info PureUtils.RegularFunIdMap.t; - (** Information used to filter and name functions - we use it - to print comments in the generated code, to help link - the generated code to the original code (information such - as: "this function is the backward function of ...", or - "this function is the merged forward/backward function of ..." - in case a Rust function only has one backward translation - and we filter the forward function because it returns unit. - *) trait_decl_id : trait_decl_id option; (** If we are extracting a trait declaration, identifies it *) is_provided_method : bool; @@ -669,14 +653,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = ^ TraitClauseId.to_string clause_id | TraitItemId (id, name) -> "trait_item_id: " ^ trait_decl_id_to_string id ^ ", type name: " ^ name - | TraitMethodId (trait_decl_id, fun_name, rg_id) -> - let fwd_back_kind = - match rg_id with - | None -> "forward" - | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id - in - trait_decl_id_to_string trait_decl_id - ^ ", method name (" ^ fwd_back_kind ^ "): " ^ fun_name + | TraitMethodId (trait_decl_id, fun_name) -> + trait_decl_id_to_string trait_decl_id ^ ", method name: " ^ fun_name | TraitSelfClauseId -> "trait_self_clause" let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = @@ -695,8 +673,8 @@ let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string = ctx_get (FunId id) ctx let ctx_get_local_function (id : A.FunDeclId.id) (lp : LoopId.id option) - (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = - ctx_get_function (FromLlbc (FunId (FRegular id), lp, rg)) ctx + (ctx : extraction_ctx) : string = + ctx_get_function (FromLlbc (FunId (FRegular id), lp)) ctx let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = assert (id <> TTuple); @@ -734,8 +712,8 @@ let ctx_get_trait_type (id : trait_decl_id) (item_name : string) ctx_get_trait_item id item_name ctx let ctx_get_trait_method (id : trait_decl_id) (item_name : string) - (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string = - ctx_get (TraitMethodId (id, item_name, rg_id)) ctx + (ctx : extraction_ctx) : string = + ctx_get (TraitMethodId (id, item_name)) ctx let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id) (ctx : extraction_ctx) : string = @@ -1052,63 +1030,28 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (* No Fuel::Succ on purpose *) ] -let assumed_llbc_functions () : - (A.assumed_fun_id * T.RegionGroupId.id option * string) list = - let rg0 = Some T.RegionGroupId.zero in - let regular : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = - match !backend with - | FStar | Coq | HOL4 -> - [ - (ArrayIndexShared, None, "array_index_usize"); - (ArrayToSliceShared, None, "array_to_slice"); - (ArrayRepeat, None, "array_repeat"); - (SliceIndexShared, None, "slice_index_usize"); - ] - | Lean -> - [ - (ArrayIndexShared, None, "Array.index_usize"); - (ArrayToSliceShared, None, "Array.to_slice"); - (ArrayRepeat, None, "Array.repeat"); - (SliceIndexShared, None, "Slice.index_usize"); - ] - in - let mut_funs : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = - if !Config.return_back_funs then - match !backend with - | FStar | Coq | HOL4 -> - [ - (ArrayIndexMut, None, "array_index_mut_usize"); - (ArrayToSliceMut, None, "array_to_slice_mut"); - (SliceIndexMut, None, "slice_index_mut_usize"); - ] - | Lean -> - [ - (ArrayIndexMut, None, "Array.index_mut_usize"); - (ArrayToSliceMut, None, "Array.to_slice_mut"); - (SliceIndexMut, None, "Slice.index_mut_usize"); - ] - else - match !backend with - | FStar | Coq | HOL4 -> - [ - (ArrayIndexMut, None, "array_index_usize"); - (ArrayIndexMut, rg0, "array_update_usize"); - (ArrayToSliceMut, None, "array_to_slice"); - (ArrayToSliceMut, rg0, "array_from_slice"); - (SliceIndexMut, None, "slice_index_usize"); - (SliceIndexMut, rg0, "slice_update_usize"); - ] - | Lean -> - [ - (ArrayIndexMut, None, "Array.index_usize"); - (ArrayIndexMut, rg0, "Array.update_usize"); - (ArrayToSliceMut, None, "Array.to_slice"); - (ArrayToSliceMut, rg0, "Array.from_slice"); - (SliceIndexMut, None, "Slice.index_usize"); - (SliceIndexMut, rg0, "Slice.update_usize"); - ] - in - regular @ mut_funs +let assumed_llbc_functions () : (A.assumed_fun_id * string) list = + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexShared, "array_index_usize"); + (ArrayIndexMut, "array_index_mut_usize"); + (ArrayToSliceShared, "array_to_slice"); + (ArrayToSliceMut, "array_to_slice_mut"); + (ArrayRepeat, "array_repeat"); + (SliceIndexShared, "slice_index_usize"); + (SliceIndexMut, "slice_index_mut_usize"); + ] + | Lean -> + [ + (ArrayIndexShared, "Array.index_usize"); + (ArrayIndexMut, "Array.index_mut_usize"); + (ArrayToSliceShared, "Array.to_slice"); + (ArrayToSliceMut, "Array.to_slice_mut"); + (ArrayRepeat, "Array.repeat"); + (SliceIndexShared, "Slice.index_usize"); + (SliceIndexMut, "Slice.index_mut_usize"); + ] let assumed_pure_functions () : (pure_assumed_fun_id * string) list = match !backend with @@ -1200,8 +1143,7 @@ let initialize_names_maps () : names_maps = in let assumed_functions = List.map - (fun (fid, rg, name) -> - (FromLlbc (Pure.FunId (FAssumed fid), None, rg), name)) + (fun (fid, name) -> (FromLlbc (Pure.FunId (FAssumed fid), None), name)) init.assumed_llbc_functions @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions in @@ -1444,61 +1386,12 @@ let default_fun_loop_suffix (num_loops : int) (loop_id : LoopId.id option) : If this function admits only one loop, we omit it. *) if num_loops = 1 then "_loop" else "_loop" ^ LoopId.to_string loop_id -(** A helper function: generates a function suffix from a region group - information. +(** A helper function: generates a function suffix. TODO: move all those helpers. *) -let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) - (num_region_groups : int) (rg : region_group_info option) - ((keep_fwd, num_backs) : bool * int) : string = - let lp_suff = default_fun_loop_suffix num_loops loop_id in - - (* There are several cases: - - [rg] is [Some]: this is a forward function: - - we add "_fwd" - - [rg] is [None]: this is a backward function: - - this function has one extracted backward function: - - if the forward function has been filtered, we add nothing: - the forward function is useless, so the unique backward function - takes its place, in a way (in effect, we "merge" the forward - and the backward functions). - - otherwise we add "_back" - - this function has several backward functions: we add "_back" and an - additional suffix to identify the precise backward function - Note that we always add a suffix (in case there are no region groups, - we could not add the "_fwd" suffix) to prevent name clashes between - definitions (in particular between type and function definitions). - *) - let rg_suff = - (* TODO: make all the backends match what is done for Lean *) - match rg with - | None -> - if - (* In order to avoid name conflicts: - * - if the forward is eliminated, we add the suffix "_fwd" (it won't be used) - * - otherwise, no suffix (because the backward functions will have a suffix) - *) - num_backs = 1 && not keep_fwd - then "_fwd" - else "" - | Some rg -> - assert (num_region_groups > 0 && num_backs > 0); - if num_backs = 1 then - (* Exactly one backward function *) - if not keep_fwd then "" else "_back" - else if - (* Several region groups/backward functions: - - if all the regions in the group have names, we use those names - - otherwise we use an index - *) - List.for_all Option.is_some rg.region_names - then - (* Concatenate the region names *) - "_back" ^ String.concat "" (List.map Option.get rg.region_names) - else (* Use the region index *) - "_back" ^ RegionGroupId.to_string rg.id - in - lp_suff ^ rg_suff +let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) : string = + (* We only generate a suffix for the functions we generate from the loops *) + default_fun_loop_suffix num_loops loop_id (** Compute the name of a regular (non-assumed) function. @@ -1508,24 +1401,13 @@ let default_fun_suffix (num_loops : int) (loop_id : LoopId.id option) indices to derive unique names for the loops for instance - if there is exactly one loop, we don't need to use indices) - loop id (if pertinent) - - number of region groups - - region group information in case of a backward function - ([None] if forward function) - - pair: - - do we generate the forward function (it may have been filtered)? - - the number of *extracted backward functions* (same comment as for - the number of loops) - The number of extracted backward functions if not necessarily - equal to the number of region groups, because we may have - filtered some of them. TODO: use the fun id for the assumed functions. *) let ctx_compute_fun_name (ctx : extraction_ctx) (fname : llbc_name) - (num_loops : int) (loop_id : LoopId.id option) (num_rgs : int) - (rg : region_group_info option) (filter_info : bool * int) : string = + (num_loops : int) (loop_id : LoopId.id option) : string = let fname = ctx_compute_fun_name_no_suffix ctx fname in (* Compute the suffix *) - let suffix = default_fun_suffix num_loops loop_id num_rgs rg filter_info in + let suffix = default_fun_suffix num_loops loop_id in (* Concatenate *) fname ^ suffix @@ -1999,61 +1881,26 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : | None -> (* Not the case: "standard" registration *) let name = ctx_compute_global_name ctx def.name in - let body = FunId (FromLlbc (FunId (FRegular def.body), None, None)) in + let body = FunId (FromLlbc (FunId (FRegular def.body), None)) in let ctx = ctx_add decl (name ^ "_c") ctx in let ctx = ctx_add body (name ^ "_body") ctx in ctx -let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl) - (ctx : extraction_ctx) : string = - (* Lookup the LLBC def to compute the region group information *) - let def_id = def.def_id in - let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in - let sg = llbc_def.signature in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular def_id) - ctx.trans_ctx.fun_ctx.regions_hierarchies - in - let num_rgs = List.length regions_hierarchy in - let { keep_fwd; fwd = _; backs } = trans_group in - let num_backs = List.length backs in - let rg_info = - match def.back_id with - | None -> None - | Some rg_id -> - let rg = T.RegionGroupId.nth regions_hierarchy rg_id in - let region_names = - List.map - (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) - rg.regions - in - Some { id = rg_id; region_names } - in +let ctx_compute_fun_name (def : fun_decl) (ctx : extraction_ctx) : string = (* Add the function name *) - ctx_compute_fun_name ctx def.llbc_name def.num_loops def.loop_id num_rgs - rg_info (keep_fwd, num_backs) + ctx_compute_fun_name ctx def.llbc_name def.num_loops def.loop_id (* TODO: move to Extract *) -let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl) - (ctx : extraction_ctx) : extraction_ctx = +let ctx_add_fun_decl (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = (* Sanity check: the function should not be a global body - those are handled * separately *) assert (not def.is_global_decl_body); (* Lookup the LLBC def to compute the region group information *) let def_id = def.def_id in - let { keep_fwd; fwd = _; backs } = trans_group in - let num_backs = List.length backs in (* Add the function name *) - let def_name = ctx_compute_fun_name trans_group def ctx in - let fun_id = (Pure.FunId (FRegular def_id), def.loop_id, def.back_id) in - let ctx = ctx_add (FunId (FromLlbc fun_id)) def_name ctx in - (* Add the name info *) - { - ctx with - fun_name_info = - PureUtils.RegularFunIdMap.add fun_id { keep_fwd; num_backs } - ctx.fun_name_info; - } + let def_name = ctx_compute_fun_name def ctx in + let fun_id = (Pure.FunId (FRegular def_id), def.loop_id) in + ctx_add (FunId (FromLlbc fun_id)) def_name ctx let ctx_compute_type_decl_name (ctx : extraction_ctx) (def : type_decl) : string = diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index ee8d4831..88de31fe 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -213,11 +213,7 @@ let mk_builtin_types_map () = let builtin_types_map = mk_memoized mk_builtin_types_map -type builtin_fun_info = { - rg : Types.RegionGroupId.id option; - extract_name : string; -} -[@@deriving show] +type builtin_fun_info = { extract_name : string } [@@deriving show] (** The assumed functions. @@ -225,21 +221,11 @@ type builtin_fun_info = { parameters. For instance, in the case of the `Vec` functions, there is a type parameter for the allocator to use, which we want to filter. *) -let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list - = - let rg0 = Some Types.RegionGroupId.zero in +let builtin_funs () : (pattern * bool list option * builtin_fun_info) list = (* Small utility *) let mk_fun (rust_name : string) (extract_name : string option) - (filter : bool list option) (with_back : bool) (back_no_suffix : bool) : - pattern * bool list option * builtin_fun_info list = - (* [back_no_suffix] is used to control whether the backward function should - have the suffix "_back" or not (if not, then the forward function has the - prefix "_fwd", and is filtered anyway). This is pertinent only if we split - the fwd/back functions. *) - let back_no_suffix = back_no_suffix && not !Config.return_back_funs in - (* Same for the [with_back] option: this is pertinent only if we split - the fwd/back functions *) - let with_back = with_back && not !Config.return_back_funs in + (filter : bool list option) : + pattern * bool list option * builtin_fun_info = let rust_name = try parse_pattern rust_name with Failure _ -> @@ -251,68 +237,51 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list | Some name -> split_on_separator name in let basename = flatten_name extract_name in - let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in - let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in - let back_suffix = if with_back && back_no_suffix then "" else "_back" in - let back = - if with_back then [ { rg = rg0; extract_name = basename ^ back_suffix } ] - else [] - in - (rust_name, filter, fwd @ back) + let f = { extract_name = basename } in + (rust_name, filter, f) in [ - mk_fun "core::mem::replace" None None true false; + mk_fun "core::mem::replace" None None; mk_fun "core::slice::{[@T]}::len" (Some (backend_choice "slice::len" "Slice::len")) - None true false; + None; mk_fun "alloc::vec::{alloc::vec::Vec<@T, alloc::alloc::Global>}::new" - (Some "alloc::vec::Vec::new") None false false; + (Some "alloc::vec::Vec::new") None; mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::push" None - (Some [ true; false ]) - true true; + (Some [ true; false ]); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::insert" None - (Some [ true; false ]) - true true; + (Some [ true; false ]); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::len" None - (Some [ true; false ]) - true false; + (Some [ true; false ]); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::index" None - (Some [ true; true; false ]) - true false; + (Some [ true; true; false ]); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::index_mut" None - (Some [ true; true; false ]) - true false; - mk_fun "alloc::boxed::{Box<@T>}::deref" None - (Some [ true; false ]) - true false; - mk_fun "alloc::boxed::{Box<@T>}::deref_mut" None - (Some [ true; false ]) - true false; - mk_fun "core::slice::index::{[@T]}::index" None None true false; - mk_fun "core::slice::index::{[@T]}::index_mut" None None true false; - mk_fun "core::array::{[@T; @C]}::index" None None true false; - mk_fun "core::array::{[@T; @C]}::index_mut" None None true false; + (Some [ true; true; false ]); + mk_fun "alloc::boxed::{Box<@T>}::deref" None (Some [ true; false ]); + mk_fun "alloc::boxed::{Box<@T>}::deref_mut" None (Some [ true; false ]); + mk_fun "core::slice::index::{[@T]}::index" None None; + mk_fun "core::slice::index::{[@T]}::index_mut" None None; + mk_fun "core::array::{[@T; @C]}::index" None None; + mk_fun "core::array::{[@T; @C]}::index_mut" None None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get" - (Some "core::slice::index::RangeUsize::get") None true false; + (Some "core::slice::index::RangeUsize::get") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_mut" - (Some "core::slice::index::RangeUsize::get_mut") None true false; + (Some "core::slice::index::RangeUsize::get_mut") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::index" - (Some "core::slice::index::RangeUsize::index") None true false; + (Some "core::slice::index::RangeUsize::index") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::index_mut" - (Some "core::slice::index::RangeUsize::index_mut") None true false; + (Some "core::slice::index::RangeUsize::index_mut") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_unchecked" - (Some "core::slice::index::RangeUsize::get_unchecked") None false false; + (Some "core::slice::index::RangeUsize::get_unchecked") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_unchecked_mut" - (Some "core::slice::index::RangeUsize::get_unchecked_mut") None false - false; - mk_fun "core::slice::index::{usize}::get" None None true false; - mk_fun "core::slice::index::{usize}::get_mut" None None true false; - mk_fun "core::slice::index::{usize}::get_unchecked" None None false false; - mk_fun "core::slice::index::{usize}::get_unchecked_mut" None None false - false; - mk_fun "core::slice::index::{usize}::index" None None true false; - mk_fun "core::slice::index::{usize}::index_mut" None None true false; + (Some "core::slice::index::RangeUsize::get_unchecked_mut") None; + mk_fun "core::slice::index::{usize}::get" None None; + mk_fun "core::slice::index::{usize}::get_mut" None None; + mk_fun "core::slice::index::{usize}::get_unchecked" None None; + mk_fun "core::slice::index::{usize}::get_unchecked_mut" None None; + mk_fun "core::slice::index::{usize}::index" None None; + mk_fun "core::slice::index::{usize}::index_mut" None None; ] let mk_builtin_funs_map () = @@ -407,15 +376,14 @@ type builtin_trait_decl_info = { - a Rust name - an extraction name - a list of clauses *) - methods : (string * builtin_fun_info list) list; + methods : (string * builtin_fun_info) list; } [@@deriving show] let builtin_trait_decls_info () = - let rg0 = Some Types.RegionGroupId.zero in let mk_trait (rust_name : string) ?(extract_name : string option = None) ?(parent_clauses : string list = []) ?(types : string list = []) - ?(methods : (string * bool) list = []) () : builtin_trait_decl_info = + ?(methods : string list = []) () : builtin_trait_decl_info = let rust_name = parse_pattern rust_name in let extract_name = match extract_name with @@ -443,22 +411,14 @@ let builtin_trait_decls_info () = List.map mk_type types in let methods = - let mk_method (item_name, with_back) = + let mk_method item_name = (* TODO: factor out with builtin_funs_info *) let basename = if !record_fields_short_names then item_name else extract_name ^ "_" ^ item_name in - let back_no_suffix = false in - let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in - let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in - let back_suffix = if with_back && back_no_suffix then "" else "_back" in - let back = - if with_back then - [ { rg = rg0; extract_name = basename ^ back_suffix } ] - else [] - in - (item_name, fwd @ back) + let fwd = { extract_name = basename } in + (item_name, fwd) in List.map mk_method methods in @@ -474,21 +434,17 @@ let builtin_trait_decls_info () = in [ (* Deref *) - mk_trait "core::ops::deref::Deref" ~types:[ "Target" ] - ~methods:[ ("deref", true) ] + mk_trait "core::ops::deref::Deref" ~types:[ "Target" ] ~methods:[ "deref" ] (); (* DerefMut *) mk_trait "core::ops::deref::DerefMut" ~parent_clauses:[ "derefInst" ] - ~methods:[ ("deref_mut", true) ] - (); + ~methods:[ "deref_mut" ] (); (* Index *) - mk_trait "core::ops::index::Index" ~types:[ "Output" ] - ~methods:[ ("index", true) ] + mk_trait "core::ops::index::Index" ~types:[ "Output" ] ~methods:[ "index" ] (); (* IndexMut *) mk_trait "core::ops::index::IndexMut" ~parent_clauses:[ "indexInst" ] - ~methods:[ ("index_mut", true) ] - (); + ~methods:[ "index_mut" ] (); (* Sealed *) mk_trait "core::slice::index::private_slice_index::Sealed" (); (* SliceIndex *) @@ -496,12 +452,12 @@ let builtin_trait_decls_info () = ~types:[ "Output" ] ~methods: [ - ("get", true); - ("get_mut", true); - ("get_unchecked", false); - ("get_unchecked_mut", false); - ("index", true); - ("index_mut", true); + "get"; + "get_mut"; + "get_unchecked"; + "get_unchecked_mut"; + "index"; + "index_mut"; ] (); ] diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index a3dbf3cc..05b71b9f 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -272,7 +272,7 @@ let start_fun_decl_group (ctx : extraction_ctx) (fmt : F.formatter) if is_single_opaque_fun_decl_group dg then () else let compute_fun_def_name (def : Pure.fun_decl) : string = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx ^ "_def" + ctx_get_local_function def.def_id def.loop_id ctx ^ "_def" in let names = List.map compute_fun_def_name dg in (* Add a break before *) diff --git a/compiler/Main.ml b/compiler/Main.ml index 4a2d01dc..3f5e62ad 100644 --- a/compiler/Main.ml +++ b/compiler/Main.ml @@ -72,12 +72,6 @@ let () = Arg.Symbol (backend_names, set_backend), " Specify the target backend" ); ("-dest", Arg.Set_string dest_dir, " Specify the output directory"); - ( "-no-filter-useless-calls", - Arg.Clear filter_useless_monadic_calls, - " Do not filter the useless function calls" ); - ( "-no-filter-useless-funs", - Arg.Clear filter_useless_functions, - " Do not filter the useless forward/backward functions" ); ( "-test-units", Arg.Set test_unit_functions, " Test the unit functions with the concrete (i.e., not symbolic) \ @@ -120,9 +114,6 @@ let () = " Generate a default lakefile.lean (Lean only)" ); ("-print-llbc", Arg.Set print_llbc, " Print the imported LLBC"); ("-k", Arg.Clear fail_hard, " Do not fail hard in case of error"); - ( "-split-fwd-back", - Arg.Clear return_back_funs, - " Split the forward and backward functions." ); ( "-tuple-nested-proj", Arg.Set use_nested_tuple_projectors, " Use nested projectors for tuples (e.g., (0, 1).snd.fst instead of \ diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 66475d02..21ca7f08 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -462,21 +462,13 @@ let inst_fun_sig_to_string (env : fmt_env) (sg : inst_fun_sig) : string = let all_types = List.append inputs [ output ] in String.concat " -> " all_types -let fun_suffix (lp_id : LoopId.id option) (rg_id : T.RegionGroupId.id option) : - string = +let fun_suffix (lp_id : LoopId.id option) : string = let lp_suff = match lp_id with | None -> "" | Some lp_id -> "^loop" ^ LoopId.to_string lp_id in - - let rg_suff = - match rg_id with - | None -> "" - | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id - in - - lp_suff ^ rg_suff + lp_suff let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = match fid with @@ -505,7 +497,7 @@ let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string = let regular_fun_id_to_string (env : fmt_env) (fun_id : fun_id) : string = match fun_id with - | FromLlbc (fid, lp_id, rg_id) -> + | FromLlbc (fid, lp_id) -> let f = match fid with | FunId (FRegular fid) -> fun_decl_id_to_string env fid @@ -513,7 +505,7 @@ let regular_fun_id_to_string (env : fmt_env) (fun_id : fun_id) : string = | TraitMethod (trait_ref, method_name, _) -> trait_ref_to_string env true trait_ref ^ "." ^ method_name in - f ^ fun_suffix lp_id rg_id + f ^ fun_suffix lp_id | Pure fid -> pure_assumed_fun_id_to_string fid let unop_to_string (unop : unop) : string = @@ -746,7 +738,7 @@ and emeta_to_string (env : fmt_env) (meta : emeta) : string = let fun_decl_to_string (env : fmt_env) (def : fun_decl) : string = let env = { env with generics = def.signature.generics } in - let name = def.name ^ fun_suffix def.loop_id def.back_id in + let name = def.name ^ fun_suffix def.loop_id in let signature = fun_sig_to_string env def.signature in match def.body with | None -> "val " ^ name ^ " :\n " ^ signature diff --git a/compiler/Pure.ml b/compiler/Pure.ml index a879ba37..dd7a4acf 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -560,8 +560,7 @@ type fun_id_or_trait_method_ref = [@@deriving show, ord] (** A function id for a non-assumed function *) -type regular_fun_id = - fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option +type regular_fun_id = fun_id_or_trait_method_ref * LoopId.id option [@@deriving show, ord] (** A function identifier *) @@ -1078,7 +1077,6 @@ type fun_decl = { *) loop_id : LoopId.id option; (** [Some] if this definition was generated for a loop *) - back_id : RegionGroupId.id option; llbc_name : llbc_name; (** The original LLBC name. *) name : string; (** We use the name only for printing purposes (for debugging): diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index ec64df21..04bc90d7 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -925,156 +925,9 @@ let inline_useless_var_reassignments (ctx : trans_ctx) ~(inline_named : bool) in { def with body = Some body } -(** For the cases where we split the forward/backward functions. - - Given a forward or backward function call, is there, for every execution - path, a child backward function called later with exactly the same input - list prefix. We use this to filter useless function calls: if there are - such child calls, we can remove this one (in case its outputs are not - used). - We do this check because we can't simply remove function calls whose - outputs are not used, as they might fail. However, if a function fails, - its children backward functions then fail on the same inputs (ignoring - the additional inputs those receive). - - For instance, if we have: - {[ - fn f<'a>(x : &'a mut T); - ]} - - We often have things like this in the synthesized code: - {[ - _ <-- f@fwd x; - ... - nx <-- f@back'a x y; - ... - ]} - - If [f@back'a x y] fails, then necessarily [f@fwd x] also fails. - In this situation, we can remove the call [f@fwd x]. - *) -let expression_contains_child_call_in_all_paths (ctx : trans_ctx) - (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option) - (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args) - (args0 : texpression list) (e : texpression) : bool = - let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args) - (args1 : texpression list) : bool = - (* Check the fun_ids, to see if call1's function is a child of call0's function *) - match fun_id1 with - | Fun (FromLlbc (id1, lp_id1, rg_id1)) -> - (* Both are "regular" calls: check if they come from the same rust function *) - if id0 = id1 && lp_id0 = lp_id1 then - (* Same rust functions: check the regions hierarchy *) - let call1_is_child = - match (rg_id0, rg_id1) with - | None, _ -> - (* The function used in call0 is the forward function: the one - * used in call1 is necessarily a child *) - true - | Some _, None -> - (* Opposite of previous case *) - false - | Some rg_id0, Some rg_id1 -> - if rg_id0 = rg_id1 then true - else - (* We need to use the regions hierarchy *) - let regions_hierarchy = - let id0 = - match id0 with - | FunId fun_id -> fun_id - | TraitMethod (_, _, fun_decl_id) -> FRegular fun_decl_id - in - LlbcAstUtils.FunIdMap.find id0 - ctx.fun_ctx.regions_hierarchies - in - (* Compute the set of ancestors of the function in call1 *) - let call1_ancestors = - LlbcAstUtils.list_ancestor_region_groups regions_hierarchy - rg_id1 - in - (* Check if the function used in call0 is inside *) - T.RegionGroupId.Set.mem rg_id0 call1_ancestors - in - (* If call1 is a child, then we need to check if the input arguments - * used in call0 are a prefix of the input arguments used in call1 - * (note call1 being a child, it will likely consume strictly more - * given back values). - * *) - if call1_is_child then - let call1_args = - Collections.List.prefix (List.length args0) args1 - in - let args = List.combine args0 call1_args in - (* Note that the input values are expressions, *which may contain - * meta-values* (which we need to ignore). *) - let input_eq (v0, v1) = - PureUtils.remove_meta v0 = PureUtils.remove_meta v1 - in - (* Compare the generics and the prefix of the input arguments *) - generics0 = generics1 && List.for_all input_eq args - else (* Not a child *) - false - else (* Not the same function *) - false - | _ -> false - in - - let visitor = - object (self) - inherit [_] reduce_expression - method zero _ = false - method plus b0 b1 _ = b0 () && b1 () - - method! visit_texpression env e = - match e.e with - | Var _ | CVar _ | Const _ -> fun _ -> false - | StructUpdate _ -> - (* There shouldn't be monadic calls in structure updates - also - note that by returning [false] we are conservative: we might - *prevent* possible optimisations (i.e., filtering some function - calls), which is sound. *) - fun _ -> false - | Let (_, _, re, e) -> ( - match opt_destruct_function_call re with - | None -> fun () -> self#visit_texpression env e () - | Some (func1, generics1, args1) -> - let call_is_child = check_call func1 generics1 args1 in - if call_is_child then fun () -> true - else fun () -> self#visit_texpression env e ()) - | Lambda (_, e) -> self#visit_texpression env e - | App _ -> ( - fun () -> - match opt_destruct_function_call e with - | Some (func1, tys1, args1) -> check_call func1 tys1 args1 - | None -> false) - | Qualif _ -> - (* Note that this case includes functions without arguments *) - fun () -> false - | Meta (_, e) -> self#visit_texpression env e - | Loop loop -> - (* We only visit the *function end* *) - self#visit_texpression env loop.fun_end - | Switch (_, body) -> self#visit_switch_body env body - - method! visit_switch_body env body = - match body with - | If (e1, e2) -> - fun () -> - self#visit_texpression env e1 () - && self#visit_texpression env e2 () - | Match branches -> - fun () -> - List.for_all - (fun br -> self#visit_texpression env br.branch ()) - branches - end - in - visitor#visit_texpression () e () - (** Filter the useless assignments (removes the useless variables, filters the function calls) *) -let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) - (def : fun_decl) : fun_decl = +let filter_useless (_ctx : trans_ctx) (def : fun_decl) : fun_decl = (* We first need a transformation on *left-values*, which filters the useless * variables and tells us whether the value contains any variable which has * not been replaced by [_] (in which case we need to keep the assignment, @@ -1166,30 +1019,8 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) if not monadic then (* Not a monadic let-binding: simple case *) (e.e, fun _ -> used) - else - (* Monadic let-binding: trickier. - * We can filter if the right-expression is a function call, - * under some conditions. *) - match (filter_monadic_calls, opt_destruct_function_call re) with - | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) -> - (* If we split the forward/backward functions. - - We need to check if there is a child call - see - the comments for: - [expression_contains_child_call_in_all_paths] *) - if not !Config.return_back_funs then - let has_child_call = - expression_contains_child_call_in_all_paths ctx fid - lp_id rg_id tys args e - in - if has_child_call then (* Filter *) - (e.e, fun _ -> used) - else (* No child call: don't filter *) - dont_filter () - else dont_filter () - | _ -> - (* Not an LLBC function call or not allowed to filter: we can't filter *) - dont_filter () + else (* Monadic let-binding: can't filter *) + dont_filter () else (* There are used variables: don't filter *) dont_filter () | Loop loop -> @@ -1442,22 +1273,6 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let body = { body with body = body_exp } in { def with body = Some body } -(** Return [None] if the function is a backward function with no outputs (so - that we eliminate the definition which is useless). - - Note that the calls to such functions are filtered when translating from - symbolic to pure. Here, we remove the definitions altogether, because they - are now useless - *) -let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = - if - !Config.filter_useless_functions - && Option.is_some def.back_id - && def.signature.output = mk_result_ty mk_unit_ty - || def.signature.output = mk_unit_ty - then None - else Some def - (** Retrieve the loop definitions from the function definition. {!SymbolicToPure} generates an AST in which the loop bodies are part of @@ -1530,14 +1345,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : info.num_inputs_with_fuel_no_state info.num_inputs_with_fuel_with_state in - let back_inputs = - if !Config.return_back_funs then [] - else - snd - (Collections.List.split_at fun_sig.inputs - info.num_inputs_with_fuel_with_state) - in - List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ] + List.concat [ fuel; fwd_inputs; fwd_state ] in let output = loop.output_ty in @@ -1618,7 +1426,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : kind = def.kind; num_loops; loop_id = Some loop.loop_id; - back_id = def.back_id; llbc_name = def.llbc_name; name = def.name; signature = loop_sig; @@ -1640,35 +1447,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : let loops = List.map snd (LoopId.Map.bindings !loops) in (def, loops) -(** Return [false] if the forward function is useless and should be filtered. - - - a forward function with no output (comes from a Rust function with - unit return type) - - the function has mutable borrows as inputs (which is materialized - by the fact we generated backward functions which were not filtered). - - In such situation, every call to the Rust function will be translated to: - - a call to the forward function which returns nothing - - calls to the backward functions - As a failing backward function implies the forward function also fails, - we can filter the calls to the forward function, which thus becomes - useless. - In such situation, we can remove the forward function definition - altogether. - *) -let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = - (* The question of filtering the forward functions arises only if we split - the forward/backward functions *) - if !Config.return_back_funs then true - else if - (* Note that at this point, the output types are no longer seen as tuples: - * they should be lists of length 1. *) - !Config.filter_useless_functions - && fwd.f.signature.output = mk_result_ty mk_unit_ty - && backs <> [] - then false - else true - (** Convert the unit variables to [()] if they are used as right-values or [_] if they are used as left values in patterns. *) let unit_vars_to_unit (def : fun_decl) : fun_decl = @@ -1724,19 +1502,17 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = * could have: [box_new f x]) * *) match fun_id with - | Fun (FromLlbc (FunId (FAssumed aid), _lp_id, rg_id)) -> ( - match (aid, rg_id) with - | BoxNew, _ -> - assert (rg_id = None); + | Fun (FromLlbc (FunId (FAssumed aid), _lp_id)) -> ( + match aid with + | BoxNew -> let arg, args = Collections.List.pop args in mk_apps arg args - | BoxFree, _ -> + | BoxFree -> assert (args = []); mk_unit_rvalue - | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared - | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut - | ArrayRepeat ), - _ ) -> + | SliceIndexShared | SliceIndexMut | ArrayIndexShared + | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut + | ArrayRepeat -> super#visit_texpression env e) | _ -> super#visit_texpression env e) | _ -> super#visit_texpression env e @@ -1989,7 +1765,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = (lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Filter the useless variables, assignments, function calls, etc. *) - let def = filter_useless !Config.filter_useless_monadic_calls ctx def in + let def = filter_useless ctx def in log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Simplify the lets immediately followed by a return. @@ -2130,16 +1906,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : *) let all_decls = List.concat - (List.concat - (List.concat - (List.map - (fun { fwd; backs; _ } -> - [ fwd.f :: fwd.loops ] - :: List.map - (fun { f = back; loops = loops_back } -> - [ back :: loops_back ]) - backs) - transl))) + (List.concat (List.map (fun { f; loops } -> [ f :: loops ]) transl)) in let subgroups = ReorderDecls.group_reorder_fun_decls all_decls in @@ -2207,7 +1974,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) -> + | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id'))) -> if (fun_id', loop_id') = fun_id then ( (* For each argument, check if it is exactly the original input parameter. Note that there shouldn't be partial @@ -2357,8 +2124,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _))) - -> ( + | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id))) -> ( match FunLoopIdMap.find_opt (fun_id, loop_id) !used_map with @@ -2400,13 +2166,8 @@ let filter_loop_inputs (transl : pure_fun_translation list) : in let transl = List.map - (fun trans -> - let filter_fun_and_loops f = - { f = filter_in_one f.f; loops = List.map filter_in_one f.loops } - in - let fwd = filter_fun_and_loops trans.fwd in - let backs = List.map filter_fun_and_loops trans.backs in - { trans with fwd; backs }) + (fun f -> + { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }) transl in @@ -2420,18 +2181,11 @@ let filter_loop_inputs (transl : pure_fun_translation list) : it thus returns the pair: (function def, loop defs). See {!decompose_loops} for more information. - Will return [None] if the function is a backward function with no outputs. - [ctx]: used only for printing. *) -let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : - fun_and_loops option = +let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_and_loops = (* Debug *) - log#ldebug - (lazy - ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string def.back_id - ^ ")")); + log#ldebug (lazy ("PureMicroPasses.apply_passes_to_def: " ^ def.name)); log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); @@ -2451,29 +2205,13 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : let def = remove_meta def in log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - (* Remove the backward functions with no outputs. + (* Extract the loop definitions by removing the {!Loop} node *) + let def, loops = decompose_loops ctx def in - Note that the *calls* to those functions should already have been removed, - when translating from symbolic to pure. Here, we remove the definitions - altogether, because they are now useless *) - let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in - let opt_def = filter_if_backward_with_no_outputs def in - - match opt_def with - | None -> - log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n")); - None - | Some def -> - log#ldebug - (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n")); - - (* Extract the loop definitions by removing the {!Loop} node *) - let def, loops = decompose_loops ctx def in - - (* Apply the remaining passes *) - let f = apply_end_passes_to_def ctx def in - let loops = List.map (apply_end_passes_to_def ctx) loops in - Some { f; loops } + (* Apply the remaining passes *) + let f = apply_end_passes_to_def ctx def in + let loops = List.map (apply_end_passes_to_def ctx) loops in + { f; loops } (** Apply the micro-passes to a list of forward/backward translations. @@ -2489,18 +2227,11 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : but convenient. *) let apply_passes_to_pure_fun_translations (ctx : trans_ctx) - (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list = - let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation = - (* Apply the passes to the individual functions *) - let fwd, backs = trans in - let fwd = Option.get (apply_passes_to_def ctx fwd) in - let backs = List.filter_map (apply_passes_to_def ctx) backs in - (* Compute whether we need to filter the forward function or not *) - let keep_fwd = keep_forward fwd backs in - { keep_fwd; fwd; backs } - in - - let transl = List.map apply_to_one transl in + (transl : fun_decl list) : pure_fun_translation list = + (* Apply the micro-passes *) + let transl = List.map (apply_passes_to_def ctx) transl in - (* Filter the useless inputs in the loop functions *) + (* Filter the useless inputs in the loop functions (loops are initially + parameterized by *all* the symbolic values in the context, because + they may access any of them). *) filter_loop_inputs transl diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml index 53c94ff4..f5443e03 100644 --- a/compiler/ReorderDecls.ml +++ b/compiler/ReorderDecls.ml @@ -5,11 +5,7 @@ open Pure (** The local logger *) let log = Logging.reorder_decls_log -type fun_id = { - def_id : FunDeclId.id; - lp_id : LoopId.id option; - rg_id : T.RegionGroupId.id option; -} +type fun_id = { def_id : FunDeclId.id; lp_id : LoopId.id option } [@@deriving show, ord] module FunIdOrderedType : OrderedType with type t = fun_id = struct @@ -43,11 +39,11 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t = | FunOrOp (Fun fid) -> ( match fid with | Pure _ -> () - | FromLlbc (fid, lp_id, rg_id) -> ( + | FromLlbc (fid, lp_id) -> ( match fid with | FunId (FAssumed _) -> () | TraitMethod (_, _, fid) | FunId (FRegular fid) -> - let id = { def_id = fid; lp_id; rg_id } in + let id = { def_id = fid; lp_id } in ids := FunIdSet.add id !ids)) end in @@ -71,7 +67,7 @@ let group_reorder_fun_decls (decls : fun_decl list) : (bool * fun_decl list) list = let module IntMap = MakeMap (OrderedInt) in let get_fun_id (decl : fun_decl) : fun_id = - { def_id = decl.def_id; lp_id = decl.loop_id; rg_id = decl.back_id } + { def_id = decl.def_id; lp_id = decl.loop_id } in (* Compute the list/set of identifiers *) let idl = List.map get_fun_id decls in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 3a50e495..2db5f66c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -805,11 +805,9 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) that we need to call. This function may be [None] if it has to be ignored (because it does nothing). *) -let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) - (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) - (inherited_args : texpression list) (back_args : texpression list) - (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : - bs_ctx * texpression option = +let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) + (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) + : bs_ctx * texpression option = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -819,29 +817,9 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) let abstractions = V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions in - (* Compute the expression corresponding to the function *) - let func = - if !Config.return_back_funs then - (* Lookup the variable introduced for the backward function *) - RegionGroupId.Map.find back_id (Option.get info.back_funs) - else - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> - let fid = translate_fun_id_or_trait_method_ref ctx fid in - Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") - in - let args = List.append inherited_args back_args in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output_ty else output_ty - in - let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp fun_id; generics } in - Some { e = Qualif func; ty = func_ty } - in + (* Compute the expression corresponding to the function. + We simply lookup the variable introduced for the backward function. *) + let func = RegionGroupId.Map.find back_id (Option.get info.back_funs) in (* Update the context and return *) ({ ctx with calls; abstractions }, func) @@ -1124,20 +1102,34 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs_no_state = List.map (fun ty -> (Some "ret", ty)) inputs_no_state in - (* In case we merge the forward/backward functions: - we consider the backward function as stateful and potentially failing + (* We consider a backward function as stateful and potentially failing **only if it has inputs** (for the "potentially failing": if it has not inputs, we directly evaluate it in the body of the forward function). + + For instance, we do the following: + {[ + // Rust + fn push<T, 'a>(v : &mut Vec<T>, x : T) { ... } + + (* Generated code: before doing unit elimination. + We return (), as well as the backward function; as the backward + function doesn't consume any inputs, it is a value that we compute + directly in the body of [push]. + *) + let push T (v : Vec T) (x : T) : Result (() * Vec T) = ... + + (* Generated code: after doing unit elimination, if we simplify the merged + fwd/back functions (see below). *) + let push T (v : Vec T) (x : T) : Result (Vec T) = ... + ]} *) let back_effect_info = - if !Config.return_back_funs then - let b = inputs_no_state <> [] in - { - back_effect_info with - stateful = back_effect_info.stateful && b; - can_fail = back_effect_info.can_fail && b; - } - else back_effect_info + let b = inputs_no_state <> [] in + { + back_effect_info with + stateful = back_effect_info.stateful && b; + can_fail = back_effect_info.can_fail && b; + } in let state = if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] @@ -1145,8 +1137,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in let filter = - !Config.simplify_merged_fwd_backs - && !Config.return_back_funs && inputs = [] && outputs = [] + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] in let info = { @@ -1186,7 +1177,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed } in let ignore_output = - if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then + if !Config.simplify_merged_fwd_backs then ty_is_unit fwd_output && List.exists (fun (info : back_sg_info) -> not info.filter) @@ -1296,10 +1287,10 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) (subst : (generic_args * trait_instance_id) option) : ty option list = List.map (Option.map snd) (compute_back_tys_with_info dsg subst) -(** In case we merge the fwd/back functions: compute the output type of - a function, from a decomposed signature. *) +(** Compute the output type of a function, from a decomposed signature + (the output type contains the type of the value returned by the forward + function as well as the types of the returned backward functions). *) let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = - assert !Config.return_back_funs; (* Compute the arrow types for all the backward functions *) let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in (* Group the forward output and the types of the backward functions *) @@ -1315,8 +1306,8 @@ let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = in mk_output_ty_from_effect_info effect_info output -let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) - (gid : RegionGroupId.id option) : fun_sig = +let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) : fun_sig + = let generics = dsg.generics in let llbc_generics = dsg.llbc_generics in let preds = dsg.preds in @@ -1329,27 +1320,10 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid, info.effect_info)) (RegionGroupId.Map.bindings dsg.back_sg)) in - let mk_output_ty = mk_output_ty_from_effect_info in let inputs, output = - (* Two cases depending on whether we split the forward/backward functions or not *) - if !Config.return_back_funs then ( - assert (gid = None); - let output = compute_output_ty_from_decomposed dsg in - let inputs = dsg.fwd_inputs in - (inputs, output)) - else - match gid with - | None -> - let effect_info = dsg.fwd_info.effect_info in - let output = mk_output_ty effect_info dsg.fwd_output in - (dsg.fwd_inputs, output) - | Some gid -> - let back_sg = RegionGroupId.Map.find gid dsg.back_sg in - let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty effect_info output in - (inputs, output) + let output = compute_output_ty_from_decomposed dsg in + let inputs = dsg.fwd_inputs in + (inputs, output) in { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } @@ -1933,16 +1907,14 @@ and translate_panic (ctx : bs_ctx) : texpression = *) match ctx.bid with | None -> - if !Config.return_back_funs then - let back_tys = compute_back_tys ctx.sg None in - let back_tys = List.filter_map (fun x -> x) back_tys in - let tys = - if ctx.sg.fwd_info.ignore_output then back_tys - else ctx.sg.fwd_output :: back_tys - in - let output = mk_simpl_tuple_ty tys in - mk_output output - else mk_output ctx.sg.fwd_output + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in + mk_output output | Some bid -> let output = mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs @@ -2063,7 +2035,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | S.Fun (fid, call_id) -> (* Regular function call *) let fid_t = translate_fun_id_or_trait_method_ref ctx fid in - let func = Fun (FromLlbc (fid_t, None, None)) in + let func = Fun (FromLlbc (fid_t, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = get_fun_effect_info ctx fid None None in @@ -2080,107 +2052,103 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in - (* If we do not split the forward/backward functions: generate the - variables for the backward functions returned by the forward + (* Generate the variables for the backward functions returned by the forward function. *) let ctx, ignore_fwd_output, back_funs_map, back_funs = - if !Config.return_back_funs then ( - (* We need to compute the signatures of the backward functions. *) - let sg = Option.get call.sg in - let decls_ctx = ctx.decls_ctx in - let dsg = - translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx - fid call.regions_hierarchy sg - (List.map (fun _ -> None) sg.inputs) - in - log#ldebug - (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); - let tr_self, all_generics = - match call.trait_method_generics with - | None -> (UnknownTrait __FUNCTION__, generics) - | Some (all_generics, tr_self) -> - let all_generics = - ctx_translate_fwd_generic_args ctx all_generics - in - let tr_self = - translate_fwd_trait_instance_id ctx.type_ctx.type_infos - tr_self + (* We need to compute the signatures of the backward functions. *) + let sg = Option.get call.sg in + let decls_ctx = ctx.decls_ctx in + let dsg = + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx fid + call.regions_hierarchy sg + (List.map (fun _ -> None) sg.inputs) + in + log#ldebug + (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); + let tr_self, all_generics = + match call.trait_method_generics with + | None -> (UnknownTrait __FUNCTION__, generics) + | Some (all_generics, tr_self) -> + let all_generics = + ctx_translate_fwd_generic_args ctx all_generics + in + let tr_self = + translate_fwd_trait_instance_id ctx.type_ctx.type_infos + tr_self + in + (tr_self, all_generics) + in + let back_tys = + compute_back_tys_with_info dsg (Some (all_generics, tr_self)) + in + (* Introduce variables for the backward functions *) + (* Compute a proper basename for the variables *) + let back_fun_name = + let name = + match fid with + | FunId (FAssumed fid) -> ( + match fid with + | BoxNew -> "box_new" + | BoxFree -> "box_free" + | ArrayRepeat -> "array_repeat" + | ArrayIndexShared -> "index_shared" + | ArrayIndexMut -> "index_mut" + | ArrayToSliceShared -> "to_slice_shared" + | ArrayToSliceMut -> "to_slice_mut" + | SliceIndexShared -> "index_shared" + | SliceIndexMut -> "index_mut") + | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( + let decl = + FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls in - (tr_self, all_generics) + match Collections.List.last decl.name with + | PeIdent (s, _) -> s + | PeImpl _ -> + (* We shouldn't get there *) + raise (Failure "Unexpected")) in - let back_tys = - compute_back_tys_with_info dsg (Some (all_generics, tr_self)) - in - (* Introduce variables for the backward functions *) - (* Compute a proper basename for the variables *) - let back_fun_name = - let name = - match fid with - | FunId (FAssumed fid) -> ( - match fid with - | BoxNew -> "box_new" - | BoxFree -> "box_free" - | ArrayRepeat -> "array_repeat" - | ArrayIndexShared -> "index_shared" - | ArrayIndexMut -> "index_mut" - | ArrayToSliceShared -> "to_slice_shared" - | ArrayToSliceMut -> "to_slice_mut" - | SliceIndexShared -> "index_shared" - | SliceIndexMut -> "index_mut") - | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( - let decl = - FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls - in - match Collections.List.last decl.name with - | PeIdent (s, _) -> s - | PeImpl _ -> - (* We shouldn't get there *) - raise (Failure "Unexpected")) - in - name ^ "_back" - in - let ctx, back_vars = - fresh_opt_vars - (List.map - (fun ty -> - match ty with - | None -> None - | Some (back_sg, ty) -> - (* We insert a name for the variable only if the function - can fail: if it can fail, it means the call returns a backward - function. Otherwise, we it directly returns the value given - back by the backward function, which means we shouldn't - give it a name like "back..." (it doesn't make sense) *) - let name = - if back_sg.effect_info.can_fail then - Some back_fun_name - else None - in - Some (name, ty)) - back_tys) - ctx - in - let back_funs = - List.filter_map - (fun v -> - match v with - | None -> None - | Some v -> Some (mk_typed_pattern_from_var v None)) - back_vars - in - let gids = - List.map - (fun (g : T.region_var_group) -> g.id) - call.regions_hierarchy - in - let back_vars = - List.map (Option.map mk_texpression_from_var) back_vars - in - let back_funs_map = - RegionGroupId.Map.of_list (List.combine gids back_vars) - in - (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)) - else (ctx, false, None, []) + name ^ "_back" + in + let ctx, back_vars = + fresh_opt_vars + (List.map + (fun ty -> + match ty with + | None -> None + | Some (back_sg, ty) -> + (* We insert a name for the variable only if the function + can fail: if it can fail, it means the call returns a backward + function. Otherwise, we it directly returns the value given + back by the backward function, which means we shouldn't + give it a name like "back..." (it doesn't make sense) *) + let name = + if back_sg.effect_info.can_fail then Some back_fun_name + else None + in + Some (name, ty)) + back_tys) + ctx + in + let back_funs = + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars + in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in + let back_vars = + List.map (Option.map mk_texpression_from_var) back_vars + in + let back_funs_map = + RegionGroupId.Map.of_list (List.combine gids back_vars) + in + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs) in (* Compute the pattern for the destination *) let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in @@ -2407,19 +2375,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) raise (Failure "Unreachable") in let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in - let generics = ctx_translate_fwd_generic_args ctx call.generics in - (* Retrieve the original call and the parent abstractions *) - let _forward, backwards = get_abs_ancestors ctx abs call_id in - (* Retrieve the values consumed when we called the forward function and - * ended the parent backward functions: those give us part of the input - * values (rem: for now, as we disallow nested lifetimes, there can't be - * parent backward functions). - * Note that the forward inputs **include the fuel and the input state** - * (if we use those). *) - let fwd_inputs = call_info.forward_inputs in - let back_ancestors_inputs = - List.concat (List.map (fun (_abs, args) -> args) backwards) - in (* Retrieve the values consumed upon ending the loans inside this * abstraction: those give us the remaining input values *) let back_inputs = abs_to_consumed ctx ectx abs in @@ -2434,11 +2389,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) ([ back_state ], ctx, Some nstate) else ([], ctx, None) in - (* Concatenate all the inpus *) - let inherited_inputs = - if !Config.return_back_funs then [] - else List.concat [ fwd_inputs; back_ancestors_inputs ] - in let back_inputs = List.append back_inputs back_state in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the @@ -2459,58 +2409,33 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (* Retrieve the function id, and register the function call in the context if necessary.Arith_status *) let ctx, func = - bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs - back_inputs generics output.ty ctx + bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx in (* Translate the next expression *) let next_e = translate_expression e ctx in (* Put everything together *) - let inputs = List.append inherited_inputs back_inputs in + let inputs = back_inputs in let args_mplaces = List.map (fun _ -> None) inputs in let args = List.map (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - (* **Optimization**: - ================= - We do a small optimization here if we split the forward/backward functions. - If the backward function doesn't have any output, we don't introduce any function - call. - See the comment in {!Config.filter_useless_monadic_calls}. - - TODO: use an option to disallow backward functions from updating the state. - TODO: a backward function which only gives back shared borrows shouldn't - update the state (state updates should only be used for mutable borrows, - with objects like Rc for instance). - *) - if - (not !Config.return_back_funs) - && !Config.filter_useless_monadic_calls - && outputs = [] && nstate = None - then ( - (* No outputs - we do a small sanity check: the backward function - should have exactly the same number of inputs as the forward: - this number can be different only if the forward function returned - a value containing mutable borrows, which can't be the case... *) - assert (List.length inputs = List.length fwd_inputs); - next_e) - else - (* The backward function might also have been filtered if we do not - split the forward/backward functions *) - match func with - | None -> next_e - | Some func -> - log#ldebug - (lazy - (let args = List.map (texpression_to_string ctx) args in - "func: " - ^ texpression_to_string ctx func - ^ "\nfunc type: " - ^ pure_ty_to_string ctx func.ty - ^ "\n\nargs:\n" ^ String.concat "\n" args)); - let call = mk_apps func args in - mk_let effect_info.can_fail output call next_e + (* The backward function might have been filtered it does nothing + (consumes unit and returns unit). *) + match func with + | None -> next_e + | Some func -> + log#ldebug + (lazy + (let args = List.map (texpression_to_string ctx) args in + "func: " + ^ texpression_to_string ctx func + ^ "\nfunc type: " + ^ pure_ty_to_string ctx func.ty + ^ "\n\nargs:\n" ^ String.concat "\n" args)); + let call = mk_apps func args in + mk_let effect_info.can_fail output call next_e and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2614,8 +2539,6 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in - let generics = loop_info.generics in - let fwd_inputs = Option.get loop_info.forward_inputs in (* Retrieve the additional backward inputs. Note that those are actually the backward inputs of the function we are synthesizing (and that we need to *transmit* to the loop backward function): they are not the @@ -2637,10 +2560,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inputs *) - let inputs = - if !Config.return_back_funs then List.concat [ back_inputs; back_state ] - else List.concat [ fwd_inputs; back_inputs; back_state ] - in + let inputs = List.concat [ back_inputs; back_state ] in (* Retrieve the values given back by this function *) let ctx, outputs = abs_to_given_back None abs ctx in (* Group the output values together: first the updated inputs *) @@ -2660,87 +2580,52 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output.ty else output.ty - in (* Create the expression for the function: - it is either a call to a top-level function, if we split the forward/backward functions - or a call to the variable we introduced for the backward function, if we merge the forward/backward functions *) let func = - if !Config.return_back_funs then - RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) - else - let func_ty = mk_arrows input_tys ret_ty in - let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; generics } in - Some { e = Qualif func; ty = func_ty } + RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) in - (* **Optimization**: - ================= - We do a small optimization here in case we split the forward/backward - functions. - If the backward function doesn't have any output, we don't introduce - any function call. - See the comment in {!Config.filter_useless_monadic_calls}. - - TODO: use an option to disallow backward functions from updating the state. - TODO: a backward function which only gives back shared borrows shouldn't - update the state (state updates should only be used for mutable borrows, - with objects like Rc for instance). - *) - if - (not !Config.return_back_funs) - && !Config.filter_useless_monadic_calls - && outputs = [] && nstate = None - then ( - (* No outputs - we do a small sanity check: the backward function - should have exactly the same number of inputs as the forward: - this number can be different only if the forward function returned - a value containing mutable borrows, which can't be the case... *) - assert (List.length inputs = List.length fwd_inputs); - next_e) - else - (* In case we merge the fwd/back functions we filter the backward - functions elsewhere *) - match func with - | None -> next_e - | Some func -> - let call = mk_apps func args in - (* Add meta-information - this is slightly hacky: we look at the - values consumed by the abstraction (note that those come from - *before* we applied the fixed-point context) and use them to - guide the naming of the output vars. - - Also, we need to convert the backward outputs from patterns to - variables. - - Finally, in practice, this works well only for loop bodies: - we do this only in this case. - TODO: improve the heuristics, to give weight to the hints for - instance. - *) - let next_e = - if ctx.inside_loop then - let consumed_values = abs_to_consumed ctx ectx abs in - let var_values = List.combine outputs consumed_values in - let var_values = - List.filter_map - (fun (var, v) -> - match var.Pure.value with - | PatVar (var, _) -> Some (var, v) - | _ -> None) - var_values - in - let vars, values = List.split var_values in - mk_emeta_symbolic_assignments vars values next_e - else next_e - in + (* We may have filtered the backward function elsewhere if it doesn't + do anything (doesn't consume anything and doesn't return anything) *) + match func with + | None -> next_e + | Some func -> + let call = mk_apps func args in + (* Add meta-information - this is slightly hacky: we look at the + values consumed by the abstraction (note that those come from + *before* we applied the fixed-point context) and use them to + guide the naming of the output vars. + + Also, we need to convert the backward outputs from patterns to + variables. + + Finally, in practice, this works well only for loop bodies: + we do this only in this case. + TODO: improve the heuristics, to give weight to the hints for + instance. + *) + let next_e = + if ctx.inside_loop then + let consumed_values = abs_to_consumed ctx ectx abs in + let var_values = List.combine outputs consumed_values in + let var_values = + List.filter_map + (fun (var, v) -> + match var.Pure.value with + | PatVar (var, _) -> Some (var, v) + | _ -> None) + var_values + in + let vars, values = List.split var_values in + mk_emeta_symbolic_assignments vars values next_e + else next_e + in - (* Create the let-binding *) - mk_let effect_info.can_fail output call next_e) + (* Create the let-binding *) + mk_let effect_info.can_fail output call next_e) and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -3068,48 +2953,40 @@ and translate_forward_end (ectx : C.eval_ctx) *) let ctx = (* Introduce variables for the inputs and the state variable - and update the context. *) - if !Config.return_back_funs then - (* If the forward/backward functions are not split, we need - to introduce fresh variables for the additional inputs, - because they are locally introduced in a lambda *) - let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in - let ctx, backward_inputs_no_state = - fresh_vars back_sg.inputs_no_state ctx - in - let ctx, backward_inputs_with_state = - if back_sg.effect_info.stateful then - let ctx, var, _ = bs_ctx_fresh_state_var ctx in - (ctx, backward_inputs_no_state @ [ var ]) - else (ctx, backward_inputs_no_state) - in - { - ctx with - backward_inputs_no_state = - RegionGroupId.Map.add bid backward_inputs_no_state - ctx.backward_inputs_no_state; - backward_inputs_with_state = - RegionGroupId.Map.add bid backward_inputs_with_state - ctx.backward_inputs_with_state; - } - else - (* Update the state variable *) - let back_state_var = - RegionGroupId.Map.find bid ctx.back_state_vars - in - { ctx with state_var = back_state_var } + and update the context. + + We need to introduce fresh variables for the additional inputs, + because they are locally introduced in a lambda. + *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let ctx, backward_inputs_no_state = + fresh_vars back_sg.inputs_no_state ctx + in + let ctx, backward_inputs_with_state = + if back_sg.effect_info.stateful then + let ctx, var, _ = bs_ctx_fresh_state_var ctx in + (ctx, backward_inputs_no_state @ [ var ]) + else (ctx, backward_inputs_no_state) + in + { + ctx with + backward_inputs_no_state = + RegionGroupId.Map.add bid backward_inputs_no_state + ctx.backward_inputs_no_state; + backward_inputs_with_state = + RegionGroupId.Map.add bid backward_inputs_with_state + ctx.backward_inputs_with_state; + } in let e = T.RegionGroupId.Map.find bid back_e in let finish e = (* Wrap in lambdas if necessary *) - if !Config.return_back_funs then - let inputs = - RegionGroupId.Map.find bid ctx.backward_inputs_with_state - in - let places = List.map (fun _ -> None) inputs in - mk_lambdas_from_vars inputs places e - else e + let inputs = + RegionGroupId.Map.find bid ctx.backward_inputs_with_state + in + let places = List.map (fun _ -> None) inputs in + mk_lambdas_from_vars inputs places e in (ctx, e, finish) in @@ -3131,85 +3008,83 @@ and translate_forward_end (ectx : C.eval_ctx) function, if needs be, and lookup the proper expression. *) let translate_end ctx = - if !Config.return_back_funs then - (* Compute the output of the forward function *) - let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in - let fwd_e = translate_one_end ctx None in - - (* Introduce the backward functions. *) - let back_el = - List.map - (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> - translate_one_end ctx (Some gid)) - (RegionGroupId.Map.bindings ctx.sg.back_sg) - in + (* Compute the output of the forward function *) + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in + let fwd_e = translate_one_end ctx None in - (* Compute whether the backward expressions should be evaluated straight - away or not (i.e., if we should bind them with monadic let-bindings - or not). We evaluate them straight away if they can fail and have no - inputs. *) - let evaluate_backs = - List.map - (fun (sg : back_sg_info) -> - if !Config.simplify_merged_fwd_backs then - sg.inputs = [] && sg.effect_info.can_fail - else false) - (RegionGroupId.Map.values ctx.sg.back_sg) - in + (* Introduce the backward functions. *) + let back_el = + List.map + (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> + translate_one_end ctx (Some gid)) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in - (* Introduce variables for the backward functions. - We lookup the LLBC definition in an attempt to derive pretty names - for those functions. *) - let _, back_vars = fresh_back_vars_for_current_fun ctx in + (* Compute whether the backward expressions should be evaluated straight + away or not (i.e., if we should bind them with monadic let-bindings + or not). We evaluate them straight away if they can fail and have no + inputs. *) + let evaluate_backs = + List.map + (fun (sg : back_sg_info) -> + if !Config.simplify_merged_fwd_backs then + sg.inputs = [] && sg.effect_info.can_fail + else false) + (RegionGroupId.Map.values ctx.sg.back_sg) + in - (* Create the return expressions *) - let vars = - let back_vars = List.filter_map (fun x -> x) back_vars in - if ctx.sg.fwd_info.ignore_output then back_vars - else pure_fwd_var :: back_vars - in - let vars = List.map mk_texpression_from_var vars in - let ret = mk_simpl_tuple_texpression vars in - - (* Introduce a fresh input state variable for the forward expression *) - let _ctx, state_var, state_pat = - if fwd_effect_info.stateful then - let ctx, var, pat = bs_ctx_fresh_state_var ctx in - (ctx, [ var ], [ pat ]) - else (ctx, [], []) - in + (* Introduce variables for the backward functions. + We lookup the LLBC definition in an attempt to derive pretty names + for those functions. *) + let _, back_vars = fresh_back_vars_for_current_fun ctx in - let state_var = List.map mk_texpression_from_var state_var in - let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in - let ret = mk_result_return_texpression ret in + (* Create the return expressions *) + let vars = + let back_vars = List.filter_map (fun x -> x) back_vars in + if ctx.sg.fwd_info.ignore_output then back_vars + else pure_fwd_var :: back_vars + in + let vars = List.map mk_texpression_from_var vars in + let ret = mk_simpl_tuple_texpression vars in + + (* Introduce a fresh input state variable for the forward expression *) + let _ctx, state_var, state_pat = + if fwd_effect_info.stateful then + let ctx, var, pat = bs_ctx_fresh_state_var ctx in + (ctx, [ var ], [ pat ]) + else (ctx, [], []) + in - (* Introduce all the let-bindings *) + let state_var = List.map mk_texpression_from_var state_var in + let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in + let ret = mk_result_return_texpression ret in - (* Combine: - - the backward variables - - whether we should evaluate the expression for the backward function - (i.e., should we use a monadic let-binding or not - we do if the - backward functions don't have inputs and can fail) - - the expressions for the backward functions - *) - let back_vars_els = - List.filter_map - (fun (v, (eval, el)) -> - match v with None -> None | Some v -> Some (v, eval, el)) - (List.combine back_vars (List.combine evaluate_backs back_el)) - in - let e = - List.fold_right - (fun (var, evaluate, back_e) e -> - mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) - back_vars_els ret - in - (* Bind the expression for the forward output *) - let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in - let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in - mk_let fwd_effect_info.can_fail pat fwd_e e - else translate_one_end ctx ctx.bid + (* Introduce all the let-bindings *) + + (* Combine: + - the backward variables + - whether we should evaluate the expression for the backward function + (i.e., should we use a monadic let-binding or not - we do if the + backward functions don't have inputs and can fail) + - the expressions for the backward functions + *) + let back_vars_els = + List.filter_map + (fun (v, (eval, el)) -> + match v with None -> None | Some v -> Some (v, eval, el)) + (List.combine back_vars (List.combine evaluate_backs back_el)) + in + let e = + List.fold_right + (fun (var, evaluate, back_e) e -> + mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) + back_vars_els ret + in + (* Bind the expression for the forward output *) + let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in + let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in + mk_let fwd_effect_info.can_fail pat fwd_e e in (* If we are (re-)entering a loop, we need to introduce a call to the @@ -3279,24 +3154,22 @@ and translate_forward_end (ectx : C.eval_ctx) backward functions of the outer function. *) let ctx, back_funs_map, back_funs = - if !Config.return_back_funs then - let ctx, back_vars = fresh_back_vars_for_current_fun ctx in - let back_funs = - List.filter_map - (fun v -> - match v with - | None -> None - | Some v -> Some (mk_typed_pattern_from_var v None)) - back_vars - in - let gids = RegionGroupId.Map.keys ctx.sg.back_sg in - let back_funs_map = - RegionGroupId.Map.of_list - (List.combine gids - (List.map (Option.map mk_texpression_from_var) back_vars)) - in - (ctx, Some back_funs_map, back_funs) - else (ctx, None, []) + let ctx, back_vars = fresh_back_vars_for_current_fun ctx in + let back_funs = + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars + in + let gids = RegionGroupId.Map.keys ctx.sg.back_sg in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids + (List.map (Option.map mk_texpression_from_var) back_vars)) + in + (ctx, Some back_funs_map, back_funs) in (* Introduce patterns *) @@ -3339,7 +3212,7 @@ and translate_forward_end (ectx : C.eval_ctx) let out_pat = mk_simpl_tuple_pattern out_pats in let loop_call = - let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in + let fun_id = Fun (FromLlbc (FunId fid, Some loop_id)) in let func = { id = FunOrOp fun_id; generics = loop_info.generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = @@ -3438,91 +3311,58 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* The output type of the loop function *) let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in let back_effect_infos, output_ty = - if !Config.return_back_funs then - (* The loop backward functions consume the same additional inputs as the parent - function, but have custom outputs *) - let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in - let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in - let back_info_tys = - List.map - (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> - (* Remark: the effect info of the backward function for the loop - is almost the same as for the backward function of the parent function. - Quite importantly, the fact that the function is stateful and/or can fail - mostly depends on whether it has inputs or not, and the backward functions - for the loops have the same inputs as the backward functions for the parent - function. - *) - let effect_info = back_sg.effect_info in - let effect_info = { effect_info with is_rec = true } in - (* Compute the input/output types *) - let inputs = List.map snd back_sg.inputs in - let outputs = given_back in - (* Filter if necessary *) - let ty = - if - !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] - then None - else - let output = mk_simpl_tuple_ty outputs in - let output = - mk_back_output_ty_from_effect_info effect_info inputs output - in - let ty = mk_arrows inputs output in - Some ty - in - ((id, effect_info), ty)) - (List.combine back_sgs given_back_tys) - in - let back_info = List.map fst back_info_tys in - let back_info = RegionGroupId.Map.of_list back_info in - let back_tys = List.filter_map snd back_info_tys in - let output = - if ctx.sg.fwd_info.ignore_output then back_tys - else ctx.sg.fwd_output :: back_tys - in - let output = mk_simpl_tuple_ty output in - let effect_info = ctx.sg.fwd_info.effect_info in - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if effect_info.can_fail && inputs <> [] then mk_result_ty output - else output - in - (back_info, output) - else - let back_info = - RegionGroupId.Map.of_list - (List.map - (fun ((id, back_sg) : _ * back_sg_info) -> - (id, { back_sg.effect_info with is_rec = true })) - (RegionGroupId.Map.bindings ctx.sg.back_sg)) - in - let output = - match ctx.bid with - | None -> - (* Forward function: same type as the parent function *) - (translate_fun_sig_from_decomposed ctx.sg None).output - | Some rg_id -> - (* Backward function: custom return type *) - let doutputs = - T.RegionGroupId.Map.find rg_id rg_to_given_back_tys - in - let output = mk_simpl_tuple_ty doutputs in - let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let output = - if fwd_effect_info.stateful then - mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if fwd_effect_info.can_fail then mk_result_ty output else output - in - output - in - (back_info, output) + (* The loop backward functions consume the same additional inputs as the parent + function, but have custom outputs *) + let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in + let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in + let back_info_tys = + List.map + (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> + (* Remark: the effect info of the backward function for the loop + is almost the same as for the backward function of the parent function. + Quite importantly, the fact that the function is stateful and/or can fail + mostly depends on whether it has inputs or not, and the backward functions + for the loops have the same inputs as the backward functions for the parent + function. + *) + let effect_info = back_sg.effect_info in + let effect_info = { effect_info with is_rec = true } in + (* Compute the input/output types *) + let inputs = List.map snd back_sg.inputs in + let outputs = given_back in + (* Filter if necessary *) + let ty = + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty + in + ((id, effect_info), ty)) + (List.combine back_sgs given_back_tys) + in + let back_info = List.map fst back_info_tys in + let back_info = RegionGroupId.Map.of_list back_info in + let back_tys = List.filter_map snd back_info_tys in + let output = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty output in + let effect_info = ctx.sg.fwd_info.effect_info in + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + in + (back_info, output) in (* Add the loop information in the context *) @@ -3708,31 +3548,26 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Translate *) let def = ctx.fun_decl in - let bid = ctx.bid in + assert (ctx.bid = None); log#ldebug (lazy ("SymbolicToPure.translate_fun_decl: " ^ name_to_string ctx def.name - ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string bid - ^ ")\n")); + ^ "\n")); (* Translate the declaration *) let def_id = def.def_id in let llbc_name = def.name in let name = name_to_string ctx llbc_name in (* Translate the signature *) - let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in - let regions_hierarchy = - FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies - in + let signature = translate_fun_sig_from_decomposed ctx.sg in (* Translate the body, if there is *) let body = match body with | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx (FunId (FRegular def_id)) None bid + get_fun_effect_info ctx (FunId (FRegular def_id)) None None in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) @@ -3760,37 +3595,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = if effect_info.stateful_group then [ mk_state_var ctx.state_var ] else [] in - (* Compute the list of (properly ordered) backward input variables *) - let backward_inputs : var list = - match bid with - | None -> [] - | Some back_id -> - assert (not !Config.return_back_funs); - let parents_ids = - list_ordered_ancestor_region_groups regions_hierarchy back_id - in - let backward_ids = List.append parents_ids [ back_id ] in - List.concat - (List.map - (fun id -> - T.RegionGroupId.Map.find id ctx.backward_inputs_no_state) - backward_ids) - in - (* Introduce the backward input state (the state at call site of the - * *backward* function), if necessary *) - let back_state = - if effect_info.stateful && Option.is_some bid then - let state_var = - RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars - in - [ mk_state_var state_var ] - else [] - in (* Group the inputs together *) - let inputs = - List.concat - [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ] - in + let inputs = List.concat [ fuel; ctx.forward_inputs; fwd_state ] in let inputs_lvs = List.map (fun v -> mk_typed_pattern_from_var v None) inputs in @@ -3799,16 +3605,10 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (lazy ("SymbolicToPure.translate_fun_decl: " ^ name_to_string ctx def.name - ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string bid - ^ ")" ^ "\n- forward_inputs: " + ^ "\n- inputs: " ^ String.concat ", " (List.map show_var ctx.forward_inputs) - ^ "\n- fwd_state: " + ^ "\n- state: " ^ String.concat ", " (List.map show_var fwd_state) - ^ "\n- backward_inputs: " - ^ String.concat ", " (List.map show_var backward_inputs) - ^ "\n- back_state: " - ^ String.concat ", " (List.map show_var back_state) ^ "\n- signature.inputs: " ^ String.concat ", " (List.map (pure_ty_to_string ctx) signature.inputs))); @@ -3837,7 +3637,6 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = kind = def.kind; num_loops; loop_id; - back_id = bid; llbc_name; name; signature; diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 55a94302..c12de045 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -1,5 +1,4 @@ open Interpreter -open Expressions open Types open Values open LlbcAst @@ -49,8 +48,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx) log#ldebug (lazy ("translate_function_to_pure: " ^ name_to_string trans_ctx fdef.name)); - let def_id = fdef.def_id in - (* Compute the symbolic ASTs, if the function is transparent *) let symbolic_trans = translate_function_to_symbolics trans_ctx fdef in @@ -124,20 +121,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) SymbolicToPure.translate_fun_sig_from_decl_to_decomposed trans_ctx fdef in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_ctx.regions_hierarchies - in - - let var_counter, back_state_vars = - if !Config.return_back_funs then (var_counter, []) - else - List.fold_left_map - (fun var_counter (region_vars : region_var_group) -> - let gid = region_vars.id in - let var, var_counter = Pure.VarId.fresh var_counter in - (var_counter, (gid, var))) - var_counter regions_hierarchy - in + let var_counter, back_state_vars = (var_counter, []) in let back_state_vars = RegionGroupId.Map.of_list back_state_vars in let ctx = @@ -195,28 +179,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in (* Add the backward inputs *) - let ctx, backward_inputs_no_state, backward_inputs_with_state = - if !Config.return_back_funs then (ctx, [], []) - else - let ctx, inputs_no_with_state = - List.fold_left_map - (fun ctx (region_vars : region_var_group) -> - let gid = region_vars.id in - let back_sg = RegionGroupId.Map.find gid sg.back_sg in - let ctx, no_state = - SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx - in - let ctx, with_state = - SymbolicToPure.fresh_vars back_sg.inputs ctx - in - (ctx, ((gid, no_state), (gid, with_state)))) - ctx regions_hierarchy - in - let inputs_no_state, inputs_with_state = - List.split inputs_no_with_state - in - (ctx, inputs_no_state, inputs_with_state) - in + let backward_inputs_no_state, backward_inputs_with_state = ([], []) in let backward_inputs_no_state = RegionGroupId.Map.of_list backward_inputs_no_state in @@ -225,40 +188,10 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in - (* Translate the forward function *) - let pure_forward = - match symbolic_trans with - | None -> SymbolicToPure.translate_fun_decl ctx None - | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast) - in - - (* Translate the backward functions, if we split the forward and backward functions *) - let translate_backward (rg : region_var_group) : Pure.fun_decl = - (* For the backward inputs/outputs initialization: we use the fact that - * there are no nested borrows for now, and so that the region groups - * can't have parents *) - assert (rg.parents = []); - let back_id = rg.id in - - match symbolic_trans with - | None -> - (* Initialize the context *) - let ctx = { ctx with bid = Some back_id } in - (* Translate *) - SymbolicToPure.translate_fun_decl ctx None - | Some (_, symbolic) -> - (* Initialize the context *) - let ctx = { ctx with bid = Some back_id } in - (* Translate *) - SymbolicToPure.translate_fun_decl ctx (Some symbolic) - in - let pure_backwards = - if !Config.return_back_funs then [] - else List.map translate_backward regions_hierarchy - in - - (* Return *) - (pure_forward, pure_backwards) + (* Translate the function *) + match symbolic_trans with + | None -> SymbolicToPure.translate_fun_decl ctx None + | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast) (* TODO: factor out the return type *) let translate_crate_to_pure (crate : crate) : @@ -513,9 +446,8 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) let global_decls = ctx.trans_ctx.global_ctx.global_decls in let global = GlobalDeclId.Map.find id global_decls in let trans = FunDeclId.Map.find global.body ctx.trans_funs in - assert (trans.fwd.loops = []); - assert (trans.backs = []); - let body = trans.fwd.f in + assert (trans.loops = []); + let body = trans.f in let is_opaque = Option.is_none body.Pure.body in (* Check if we extract the global *) @@ -643,7 +575,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) let funs_map = builtin_funs_map () in List.map (fun (trans : pure_fun_translation) -> - match_name_find_opt ctx.trans_ctx trans.fwd.f.llbc_name funs_map <> None) + match_name_find_opt ctx.trans_ctx trans.f.llbc_name funs_map <> None) pure_ls in @@ -660,7 +592,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Extract the decrease clauses template bodies *) if config.extract_template_decreases_clauses then List.iter - (fun { fwd; _ } -> + (fun f -> (* We only generate decreases clauses for the forward functions, because the termination argument should only depend on the forward inputs. The backward functions thus use the same decreases clauses as the @@ -687,27 +619,14 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) raise (Failure "HOL4 doesn't have decreases/termination clauses") in - extract_decrease fwd.f; - List.iter extract_decrease fwd.loops) + extract_decrease f.f; + List.iter extract_decrease f.loops) pure_ls; - (* Concatenate the function definitions, filtering the useless forward - * functions. *) + (* Flatten the translated functions (concatenate the functions with + the declarations introduced for the loops) *) let decls = - List.concat - (List.map - (fun { keep_fwd; fwd; backs } -> - let fwd = - if keep_fwd then List.append fwd.loops [ fwd.f ] else [] - in - let backs : Pure.fun_decl list = - List.concat - (List.map - (fun back -> List.append back.loops [ back.f ]) - backs) - in - List.append fwd backs) - pure_ls) + List.concat (List.map (fun f -> List.append f.loops [ f.f ]) pure_ls) in (* Extract the function definitions *) @@ -724,9 +643,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Insert unit tests if necessary *) if config.test_trans_unit_functions then List.iter - (fun trans -> - if trans.keep_fwd then - Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f) + (fun trans -> Extract.extract_unit_test_if_unit_fun ctx fmt trans.f) pure_ls (** Export a trait declaration. *) @@ -812,7 +729,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) extract their type directly in the records we generate for the trait declarations themselves, there is no point in having separate type definitions) *) - match pure_fun.fwd.f.Pure.kind with + match pure_fun.f.Pure.kind with | TraitMethodDecl _ -> () | _ -> (* Translate *) @@ -1001,18 +918,18 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : * whether we should generate a decrease clause or not. *) let rec_functions = List.map - (fun { fwd; _ } -> - let fwd_f = - if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then - [ (fwd.f.def_id, None) ] + (fun trans -> + let f = + if trans.f.Pure.signature.fwd_info.effect_info.is_rec then + [ (trans.f.def_id, None) ] else [] in - let loop_fwds = + let loops = List.map (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ]) - fwd.loops + trans.loops in - fwd_f :: loop_fwds) + f :: loops) trans_funs in let rec_functions : PureUtils.fun_loop_id list = @@ -1028,7 +945,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : let trans_funs : pure_fun_translation FunDeclId.Map.t = FunDeclId.Map.of_list (List.map - (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans)) + (fun (trans : pure_fun_translation) -> (trans.f.def_id, trans)) trans_funs) in @@ -1052,7 +969,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : names_maps; indent_incr = 2; use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses; - fun_name_info = PureUtils.RegularFunIdMap.empty; trait_decl_id = None (* None by default *); is_provided_method = false (* false by default *); trans_trait_decls; @@ -1082,7 +998,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : (fun ctx (trans : pure_fun_translation) -> (* If requested by the user, register termination measures and decreases proofs for all the recursive functions *) - let fwd_def = trans.fwd.f in let gen_decr_clause (def : Pure.fun_decl) = !Config.extract_decreases_clauses && PureUtils.FunLoopIdSet.mem @@ -1091,7 +1006,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : in (* Register the names, only if the function is not a global body - * those are handled later *) - let is_global = fwd_def.Pure.is_global_decl_body in + let is_global = trans.f.Pure.is_global_decl_body in if is_global then ctx else Extract.extract_fun_decl_register_names ctx gen_decr_clause trans) ctx @@ -1171,13 +1086,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : let exe_dir = Filename.dirname Sys.argv.(0) in let primitives_src_dest = match !Config.backend with - | FStar -> - let src = - if !Config.return_back_funs then - "/backends/fstar/merge/Primitives.fst" - else "/backends/fstar/split/Primitives.fst" - in - Some (src, "Primitives.fst") + | FStar -> Some ("/backends/fstar/merge/Primitives.fst", "Primitives.fst") | Coq -> Some ("/backends/coq/Primitives.v", "Primitives.v") | Lean -> None | HOL4 -> None diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml index 88438872..05877b5a 100644 --- a/compiler/TranslateCore.ml +++ b/compiler/TranslateCore.ml @@ -8,19 +8,8 @@ let log = Logging.translate_log type trans_ctx = decls_ctx [@@deriving show] type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list } -type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list - -type pure_fun_translation = { - keep_fwd : bool; - (** Should we extract the forward function? - - If the forward function returns `()` and there is exactly one - backward function, we may merge the forward into the backward - function and thus don't extract the forward function)? - *) - fwd : fun_and_loops; - backs : fun_and_loops list; -} +type pure_fun_translation_no_loops = Pure.fun_decl +type pure_fun_translation = fun_and_loops let trans_ctx_to_fmt_env (ctx : trans_ctx) : Print.fmt_env = Print.Contexts.decls_ctx_to_fmt_env ctx |