diff options
author | Son HO | 2023-11-22 15:06:43 +0100 |
---|---|---|
committer | GitHub | 2023-11-22 15:06:43 +0100 |
commit | bacf3f5f6f5f6a9aa650d5ae8d12a132fd747039 (patch) | |
tree | 9953d7af1fe406cdc750030a43a5e4d6245cd763 /compiler/Translate.ml | |
parent | 587f1ebc0178acb19029d3fc9a729c197082aba7 (diff) | |
parent | 01cfd899119174ef7c5941c99dd251711f4ee701 (diff) |
Merge pull request #45 from AeneasVerif/son_merge_types
Big cleanup
Diffstat (limited to '')
-rw-r--r-- | compiler/Translate.ml | 204 |
1 files changed, 98 insertions, 106 deletions
diff --git a/compiler/Translate.ml b/compiler/Translate.ml index a3d96023..05e48af5 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -1,11 +1,11 @@ -open InterpreterStatements open Interpreter -module L = Logging -module T = Types -module A = LlbcAst +open Expressions +open Types +open Values +open LlbcAst +open Contexts module SA = SymbolicAst module Micro = PureMicroPasses -module C = Contexts open PureUtils open TranslateCore @@ -16,18 +16,17 @@ let log = TranslateCore.log - the list of symbolic values used for the input values - the generated symbolic AST *) -type symbolic_fun_translation = V.symbolic_value list * SA.expression +type symbolic_fun_translation = symbolic_value list * SA.expression (** Execute the symbolic interpreter on a function to generate a list of symbolic ASTs, for the forward function and the backward functions. *) -let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl) - : symbolic_fun_translation option = +let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) : + symbolic_fun_translation option = (* Debug *) log#ldebug (lazy - ("translate_function_to_symbolics: " - ^ Print.fun_name_to_string fdef.A.name)); + ("translate_function_to_symbolics: " ^ name_to_string trans_ctx fdef.name)); match fdef.body with | None -> None @@ -45,12 +44,11 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl) *) let translate_function_to_pure (trans_ctx : trans_ctx) (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdNotLoopMap.t) - (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl) - : pure_fun_translation_no_loops = + (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) : + pure_fun_translation_no_loops = (* Debug *) log#ldebug - (lazy - ("translate_function_to_pure: " ^ Print.fun_name_to_string fdef.A.name)); + (lazy ("translate_function_to_pure: " ^ name_to_string trans_ctx fdef.name)); let def_id = fdef.def_id in @@ -61,22 +59,24 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Initialize the context *) let forward_sig = - RegularFunIdNotLoopMap.find (E.Regular def_id, None) fun_sigs + RegularFunIdNotLoopMap.find (FRegular def_id, None) fun_sigs in - let sv_to_var = V.SymbolicValueId.Map.empty in + let sv_to_var = SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in let state_var, var_counter = Pure.VarId.fresh var_counter in let back_state_var, var_counter = Pure.VarId.fresh var_counter in let fuel0, var_counter = Pure.VarId.fresh var_counter in let fuel, var_counter = Pure.VarId.fresh var_counter in - let calls = V.FunCallId.Map.empty in - let abstractions = V.AbstractionId.Map.empty in + let calls = FunCallId.Map.empty in + let abstractions = AbstractionId.Map.empty in let recursive_type_decls = - T.TypeDeclId.Set.of_list + TypeDeclId.Set.of_list (List.filter_map (fun (tid, g) -> - match g with Charon.GAst.NonRec _ -> None | Rec _ -> Some tid) - (T.TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups)) + match g with + | Charon.GAst.NonRecGroup _ -> None + | RecGroup _ -> Some tid) + (TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups)) in let type_context = { @@ -91,6 +91,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls; fun_sigs; fun_infos = trans_ctx.fun_ctx.fun_infos; + regions_hierarchies = trans_ctx.fun_ctx.regions_hierarchies; } in let global_context = @@ -103,9 +104,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) *) let loop_ids_map = match symbolic_trans with - | None -> V.LoopId.Map.empty + | None -> LoopId.Map.empty | Some (_, ast) -> - let m = ref V.LoopId.Map.empty in + let m = ref LoopId.Map.empty in let _, fresh_loop_id = Pure.LoopId.fresh_stateful_generator () in let visitor = @@ -114,10 +115,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) method! visit_loop env loop = let _ = - match V.LoopId.Map.find_opt loop.loop_id !m with + match LoopId.Map.find_opt loop.loop_id !m with | Some _ -> () - | None -> - m := V.LoopId.Map.add loop.loop_id (fresh_loop_id ()) !m + | None -> m := LoopId.Map.add loop.loop_id (fresh_loop_id ()) !m in super#visit_loop env loop end @@ -147,9 +147,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) fun_decl = fdef; forward_inputs = []; (* Empty for now *) - backward_inputs = T.RegionGroupId.Map.empty; + backward_inputs = RegionGroupId.Map.empty; (* Empty for now *) - backward_outputs = T.RegionGroupId.Map.empty; + backward_outputs = RegionGroupId.Map.empty; loop_backward_outputs = None; (* Empty for now *) calls; @@ -170,7 +170,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) | Some body, Some (input_svs, _) -> let forward_input_vars = LlbcAstUtils.fun_body_get_input_vars body in let forward_input_varnames = - List.map (fun (v : A.var) -> v.name) forward_input_vars + List.map (fun (v : var) -> v.name) forward_input_vars in let input_svs = List.combine forward_input_varnames input_svs in let ctx, forward_inputs = @@ -188,7 +188,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in (* Translate the backward functions *) - let translate_backward (rg : T.region_var_group) : Pure.fun_decl = + let translate_backward (rg : region_group) : Pure.fun_decl = (* For the backward inputs/outputs initialization: we use the fact that * there are no nested borrows for now, and so that the region groups * can't have parents *) @@ -200,7 +200,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Initialize the context - note that the ret_ty is not really * useful as we don't translate a body *) let backward_sg = - RegularFunIdNotLoopMap.find (Regular def_id, Some back_id) fun_sigs + RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs in let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in @@ -211,7 +211,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) variables required by the backward function. *) let backward_sg = - RegularFunIdNotLoopMap.find (Regular def_id, Some back_id) fun_sigs + RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs in (* We need to ignore the forward inputs, and the state input (if there is) *) let backward_inputs = @@ -243,10 +243,10 @@ let translate_function_to_pure (trans_ctx : trans_ctx) SymbolicToPure.fresh_vars backward_outputs ctx in let backward_inputs = - T.RegionGroupId.Map.singleton back_id backward_inputs + RegionGroupId.Map.singleton back_id backward_inputs in let backward_outputs = - T.RegionGroupId.Map.singleton back_id backward_outputs + RegionGroupId.Map.singleton back_id backward_outputs in (* Put everything in the context *) @@ -263,15 +263,17 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Translate *) SymbolicToPure.translate_fun_decl ctx (Some symbolic) in - let pure_backwards = - List.map translate_backward fdef.signature.regions_hierarchy + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular fdef.def_id) + fun_context.regions_hierarchies in + let pure_backwards = List.map translate_backward regions_hierarchy in (* Return *) (pure_forward, pure_backwards) (* TODO: factor out the return type *) -let translate_crate_to_pure (crate : A.crate) : +let translate_crate_to_pure (crate : crate) : trans_ctx * Pure.type_decl list * pure_fun_translation list @@ -284,9 +286,7 @@ let translate_crate_to_pure (crate : A.crate) : let trans_ctx = compute_contexts crate in (* Translate all the type definitions *) - let type_decls = - SymbolicToPure.translate_type_decls (T.TypeDeclId.Map.values crate.types) - in + let type_decls = SymbolicToPure.translate_type_decls trans_ctx in (* Compute the type definition map *) let type_decls_map = @@ -298,24 +298,24 @@ let translate_crate_to_pure (crate : A.crate) : let assumed_sigs = List.map (fun (info : Assumed.assumed_fun_info) -> - ( E.Assumed info.fun_id, + ( FAssumed info.fun_id, List.map (fun _ -> None) info.fun_sig.inputs, info.fun_sig )) Assumed.assumed_fun_infos in let local_sigs = List.map - (fun (fdef : A.fun_decl) -> + (fun (fdef : fun_decl) -> let input_names = match fdef.body with | None -> List.map (fun _ -> None) fdef.signature.inputs | Some body -> List.map - (fun (v : A.var) -> v.name) + (fun (v : var) -> v.name) (LlbcAstUtils.fun_body_get_input_vars body) in - (E.Regular fdef.def_id, input_names, fdef.signature)) - (A.FunDeclId.Map.values crate.functions) + (FRegular fdef.def_id, input_names, fdef.signature)) + (FunDeclId.Map.values crate.fun_decls) in let sigs = List.append assumed_sigs local_sigs in let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in @@ -324,22 +324,21 @@ let translate_crate_to_pure (crate : A.crate) : let pure_translations = List.map (translate_function_to_pure trans_ctx fun_sigs type_decls_map) - (A.FunDeclId.Map.values crate.functions) + (FunDeclId.Map.values crate.fun_decls) in (* Translate the trait declarations *) - let type_infos = trans_ctx.type_ctx.type_infos in let trait_decls = List.map - (SymbolicToPure.translate_trait_decl type_infos) - (T.TraitDeclId.Map.values trans_ctx.trait_decls_ctx.trait_decls) + (SymbolicToPure.translate_trait_decl trans_ctx) + (TraitDeclId.Map.values trans_ctx.trait_decls_ctx.trait_decls) in (* Translate the trait implementations *) let trait_impls = List.map - (SymbolicToPure.translate_trait_impl type_infos) - (T.TraitImplId.Map.values trans_ctx.trait_impls_ctx.trait_impls) + (SymbolicToPure.translate_trait_impl trans_ctx) + (TraitImplId.Map.values trans_ctx.trait_impls_ctx.trait_impls) in (* Apply the micro-passes *) @@ -398,9 +397,9 @@ let crate_has_opaque_non_builtin_decls (ctx : gen_ctx) (filter_assumed : bool) : log#ldebug (lazy ("Opaque decls:" ^ "\n- types:\n" - ^ String.concat ",\n" (List.map T.show_type_decl types) + ^ String.concat ",\n" (List.map show_type_decl types) ^ "\n- functions:\n" - ^ String.concat ",\n" (List.map A.show_fun_decl funs))); + ^ String.concat ",\n" (List.map show_fun_decl funs))); (types <> [], funs <> []) (** Export a type declaration. @@ -478,8 +477,7 @@ let export_types_group (fmt : Format.formatter) (config : gen_config) let types_map = builtin_types_map () in List.map (fun (def : Pure.type_decl) -> - let sname = name_to_simple_name def.name in - SimpleNameMap.find_opt sname types_map <> None) + match_name_find_opt ctx.trans_ctx def.llbc_name types_map <> None) defs in @@ -528,10 +526,10 @@ let export_types_group (fmt : Format.formatter) (config : gen_config) TODO: check correct behavior with opaque globals. *) let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) - (id : A.GlobalDeclId.id) : unit = + (id : GlobalDeclId.id) : unit = let global_decls = ctx.trans_ctx.global_ctx.global_decls in - let global = A.GlobalDeclId.Map.find id global_decls in - let trans = A.FunDeclId.Map.find global.body_id ctx.trans_funs in + let global = GlobalDeclId.Map.find id global_decls in + let trans = FunDeclId.Map.find global.body ctx.trans_funs in assert (trans.fwd.loops = []); assert (trans.backs = []); let body = trans.fwd.f in @@ -546,9 +544,9 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (* Check if it is a builtin global - if yes, we ignore it because we map the definition to one in the standard library *) let open ExtractBuiltin in - let sname = name_to_simple_name global.name in let extract = - extract && SimpleNameMap.find_opt sname builtin_globals_map = None + extract + && match_name_find_opt ctx.trans_ctx global.name builtin_globals_map = None in if extract then (* We don't wrap global declaration groups between calls to functions @@ -662,8 +660,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) let funs_map = builtin_funs_map () in List.map (fun (trans : pure_fun_translation) -> - let sname = name_to_simple_name trans.fwd.f.basename in - SimpleNameMap.find_opt sname funs_map <> None) + match_name_find_opt ctx.trans_ctx trans.fwd.f.llbc_name funs_map <> None) pure_ls in @@ -753,11 +750,14 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) let export_trait_decl (fmt : Format.formatter) (_config : gen_config) (ctx : gen_ctx) (trait_decl_id : Pure.trait_decl_id) (extract_decl : bool) (extract_extra_info : bool) : unit = - let trait_decl = T.TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls in + let trait_decl = TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls in (* Check if the trait declaration is builtin, in which case we ignore it *) let open ExtractBuiltin in - let sname = name_to_simple_name trait_decl.name in - if SimpleNameMap.find_opt sname (builtin_trait_decls_map ()) = None then ( + if + match_name_find_opt ctx.trans_ctx trait_decl.llbc_name + (builtin_trait_decls_map ()) + = None + then ( let ctx = { ctx with trait_decl_id = Some trait_decl.def_id } in if extract_decl then Extract.extract_trait_decl ctx fmt trait_decl; if extract_extra_info then @@ -768,7 +768,7 @@ let export_trait_decl (fmt : Format.formatter) (_config : gen_config) let export_trait_impl (fmt : Format.formatter) (_config : gen_config) (ctx : gen_ctx) (trait_impl_id : Pure.trait_impl_id) : unit = (* Lookup the definition *) - let trait_impl = T.TraitImplId.Map.find trait_impl_id ctx.trans_trait_impls in + let trait_impl = TraitImplId.Map.find trait_impl_id ctx.trans_trait_impls in let trait_decl = Pure.TraitDeclId.Map.find trait_impl.impl_trait.trait_decl_id ctx.trans_trait_decls @@ -776,9 +776,11 @@ let export_trait_impl (fmt : Format.formatter) (_config : gen_config) (* Check if the trait implementation is builtin *) let builtin_info = let open ExtractBuiltin in - let type_sname = name_to_simple_name trait_impl.name in - let trait_sname = name_to_simple_name trait_decl.name in - SimpleNamePairMap.find_opt (type_sname, trait_sname) + let trait_impl = + TraitImplId.Map.find trait_impl.def_id ctx.crate.trait_impls + in + match_name_with_generics_find_opt ctx.trans_ctx trait_decl.llbc_name + trait_impl.impl_trait.decl_generics (builtin_trait_impls_map ()) in match builtin_info with @@ -814,14 +816,15 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) Extract.extract_state_type fmt ctx kind in - let export_decl_group (dg : A.declaration_group) : unit = + let export_decl_group (dg : declaration_group) : unit = match dg with - | Type (NonRec id) -> + | TypeGroup (NonRecGroup id) -> if config.extract_types then export_types_group false [ id ] - | Type (Rec ids) -> if config.extract_types then export_types_group true ids - | Fun (NonRec id) -> ( + | TypeGroup (RecGroup ids) -> + if config.extract_types then export_types_group true ids + | FunGroup (NonRecGroup id) -> ( (* Lookup *) - let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in + let pure_fun = FunDeclId.Map.find id ctx.trans_funs in (* Special case: we skip trait method *declarations* (we will extract their type directly in the records we generate for the trait declarations themselves, there is no point in having @@ -831,21 +834,21 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) | _ -> (* Translate *) export_functions_group [ pure_fun ]) - | Fun (Rec ids) -> + | FunGroup (RecGroup ids) -> (* General case of mutually recursive functions *) (* Lookup *) let pure_funs = - List.map (fun id -> A.FunDeclId.Map.find id ctx.trans_funs) ids + List.map (fun id -> FunDeclId.Map.find id ctx.trans_funs) ids in (* Translate *) export_functions_group pure_funs - | Global id -> export_global id - | TraitDecl id -> + | GlobalGroup id -> export_global id + | TraitDeclGroup id -> (* TODO: update to extract groups *) if config.extract_trait_decls && config.extract_transparent then ( export_trait_decl_group id; export_trait_decl_group_extra_info id) - | TraitImpl id -> + | TraitImplGroup id -> if config.extract_trait_impls && config.extract_transparent then export_trait_impl id in @@ -983,27 +986,17 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (fi : extract_file_info) close_out out (** Translate a crate and write the synthesized code to an output file. *) -let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : +let translate_crate (filename : string) (dest_dir : string) (crate : crate) : unit = (* Translate the module to the pure AST *) let trans_ctx, trans_types, trans_funs, trans_trait_decls, trans_trait_impls = translate_crate_to_pure crate in - (* Initialize the extraction context - for now we extract only to F*. - * We initialize the names map by registering the keywords used in the - * language, as well as some primitive names ("u32", etc.) *) - let variant_concatenate_type_name = - (* For Lean, we exploit the fact that the variant name should always be - prefixed with the type name to prevent collisions *) - match !Config.backend with Coq | FStar | HOL4 -> true | Lean -> false - in - (* Initialize the names map (we insert the names of the "primitives" - declarations, and insert the names of the local declarations later) *) - let fmt, names_maps = - Extract.mk_formatter_and_names_maps trans_ctx crate.name - variant_concatenate_type_name - in + (* Initialize the names map by registering the keywords used in the + language, as well as some primitive names ("u32", etc.). + We insert the names of the local declarations later. *) + let names_maps = Extract.initialize_names_maps () in (* We need to compute which functions are recursive, in order to know * whether we should generate a decrease clause or not. *) @@ -1033,8 +1026,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : Pure.TypeDeclId.Map.of_list (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) in - let trans_funs : pure_fun_translation A.FunDeclId.Map.t = - A.FunDeclId.Map.of_list + let trans_funs : pure_fun_translation FunDeclId.Map.t = + FunDeclId.Map.of_list (List.map (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans)) trans_funs) @@ -1043,13 +1036,13 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (* Put everything in the context *) let ctx = let trans_trait_decls = - T.TraitDeclId.Map.of_list + TraitDeclId.Map.of_list (List.map (fun (d : Pure.trait_decl) -> (d.def_id, d)) trans_trait_decls) in let trans_trait_impls = - T.TraitImplId.Map.of_list + TraitImplId.Map.of_list (List.map (fun (d : Pure.trait_impl) -> (d.def_id, d)) trans_trait_impls) @@ -1058,7 +1051,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : ExtractBase.crate; trans_ctx; names_maps; - fmt; indent_incr = 2; use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses; fun_name_info = PureUtils.RegularFunIdMap.empty; @@ -1104,12 +1096,12 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : if is_global then ctx else Extract.extract_fun_decl_register_names ctx gen_decr_clause trans) ctx - (A.FunDeclId.Map.values trans_funs) + (FunDeclId.Map.values trans_funs) in let ctx = List.fold_left Extract.extract_global_decl_register_names ctx - (A.GlobalDeclId.Map.values crate.globals) + (GlobalDeclId.Map.values crate.global_decls) in let ctx = @@ -1288,7 +1280,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : namespace; in_namespace = true; crate_name; - rust_module_name = crate.A.name; + rust_module_name = crate.name; module_name = types_module; custom_msg = ": type definitions"; custom_imports = []; @@ -1316,7 +1308,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : namespace; in_namespace = true; crate_name; - rust_module_name = crate.A.name; + rust_module_name = crate.name; module_name = template_clauses_module; custom_msg = ": templates for the decreases clauses"; custom_imports = [ types_module ]; @@ -1366,7 +1358,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : namespace; in_namespace = false; crate_name; - rust_module_name = crate.A.name; + rust_module_name = crate.name; module_name = opaque_module; custom_msg; custom_imports = []; @@ -1405,7 +1397,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : namespace; in_namespace = true; crate_name; - rust_module_name = crate.A.name; + rust_module_name = crate.name; module_name = fun_module; custom_msg = ": function definitions"; custom_imports = []; @@ -1438,7 +1430,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : namespace; in_namespace = true; crate_name; - rust_module_name = crate.A.name; + rust_module_name = crate.name; module_name = crate_name; custom_msg = ""; custom_imports = []; |