From 6c88d30031255c0ac612b67bb5b3c20c2f07e563 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 13 Nov 2023 13:27:02 +0100 Subject: Add RegionsHierarchy.ml --- compiler/AssociatedTypes.ml | 5 +- compiler/Assumed.ml | 18 +-- compiler/Contexts.ml | 2 + compiler/ExtractBase.ml | 8 +- compiler/Interpreter.ml | 28 +++-- compiler/InterpreterStatements.ml | 2 +- compiler/InterpreterUtils.ml | 9 +- compiler/LlbcAstUtils.ml | 13 ++ compiler/PureMicroPasses.ml | 9 +- compiler/RegionsHierarchy.ml | 243 ++++++++++++++++++++++++++++++++++++++ compiler/ReorderDecls.ml | 97 +-------------- compiler/SCC.ml | 121 +++++++++++++++++-- compiler/Substitute.ml | 12 +- compiler/SymbolicToPure.ml | 25 ++-- compiler/Translate.ml | 7 +- compiler/dune | 1 + 16 files changed, 453 insertions(+), 147 deletions(-) create mode 100644 compiler/RegionsHierarchy.ml diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml index d5c9596e..27e08495 100644 --- a/compiler/AssociatedTypes.ml +++ b/compiler/AssociatedTypes.ml @@ -521,10 +521,11 @@ let ctx_subst_norm_signature (ctx : C.eval_ctx) (ty_subst : T.TypeVarId.id -> T.ty) (cg_subst : T.ConstGenericVarId.id -> T.const_generic) (tr_subst : T.TraitClauseId.id -> T.trait_instance_id) - (tr_self : T.trait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig = + (tr_self : T.trait_instance_id) (sg : A.fun_sig) + (regions_hierarchy : T.region_groups) : A.inst_fun_sig = let sg = Subst.substitute_signature asubst r_subst ty_subst cg_subst tr_subst tr_self - sg + sg regions_hierarchy in let { A.regions_hierarchy; inputs; output; trait_type_constraints } = sg in let inputs = List.map (ctx_normalize_ty ctx) inputs in diff --git a/compiler/Assumed.ml b/compiler/Assumed.ml index cf81502c..aa0cfccf 100644 --- a/compiler/Assumed.ml +++ b/compiler/Assumed.ml @@ -79,7 +79,7 @@ module Sig = struct let mk_slice_ty (ty : T.ty) : T.ty = TAdt (TAssumed TSlice, mk_generic_args [] [ ty ] []) - let mk_sig generics regions_hierarchy inputs output : A.fun_sig = + let mk_sig generics inputs output : A.fun_sig = let preds : T.predicates = { regions_outlive = []; types_outlive = []; trait_type_constraints = [] } in @@ -88,7 +88,6 @@ module Sig = struct generics; preds; parent_params_info = None; - regions_hierarchy; inputs; output; } @@ -96,18 +95,16 @@ module Sig = struct (** [fn(T) -> Box] *) let box_new_sig : A.fun_sig = let generics = mk_generic_params [] [ type_param_0 ] [] (* *) in - let regions_hierarchy = [] in let inputs = [ tvar_0 (* T *) ] in let output = mk_box_ty tvar_0 (* Box *) in - mk_sig generics regions_hierarchy inputs output + mk_sig generics inputs output (** [fn(Box) -> ()] *) let box_free_sig : A.fun_sig = let generics = mk_generic_params [] [ type_param_0 ] [] (* *) in - let regions_hierarchy = [] in let inputs = [ mk_box_ty tvar_0 (* Box *) ] in let output = mk_unit_ty (* () *) in - mk_sig generics regions_hierarchy inputs output + mk_sig generics inputs output (** Array/slice functions *) @@ -129,7 +126,6 @@ module Sig = struct let generics = mk_generic_params [ region_param_0 ] [ type_param_0 ] cgs (* <'a, T> *) in - let regions_hierarchy = [ region_group_0 ] (* <'a> *) in let inputs = [ mk_ref_ty rvar_0 @@ -145,7 +141,7 @@ module Sig = struct (output_ty type_param_0.index) is_mut (* &'a (mut) output_ty *) in - mk_sig generics regions_hierarchy inputs output + mk_sig generics inputs output let mk_array_slice_index_sig (is_array : bool) (is_mut : bool) : A.fun_sig = (* Array *) @@ -176,13 +172,12 @@ module Sig = struct (* *) mk_generic_params [] [ type_param_0 ] [ cg_param_0 ] in - let regions_hierarchy = [] (* <> *) in let inputs = [ tvar_0 (* T *) ] in let output = (* [T; N] *) mk_array_ty tvar_0 cgvar_0 in - mk_sig generics regions_hierarchy inputs output + mk_sig generics inputs output (** Helper: [fn(&'a [T]) -> usize] @@ -191,12 +186,11 @@ module Sig = struct let generics = mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *) in - let regions_hierarchy = [ region_group_0 ] (* <'a> *) in let inputs = [ mk_ref_ty rvar_0 (mk_slice_ty tvar_0) false (* &'a [T] *) ] in let output = mk_usize_ty (* usize *) in - mk_sig generics regions_hierarchy inputs output + mk_sig generics inputs output end type raw_assumed_fun_info = diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index 9a20a6cc..12927aab 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -2,6 +2,7 @@ open Types open Expressions open Values open LlbcAst +open LlbcAstUtils module V = Values open ValuesUtils open Identifiers @@ -190,6 +191,7 @@ type type_context = { type fun_context = { fun_decls : fun_decl FunDeclId.Map.t; fun_infos : FunsAnalysis.fun_info FunDeclId.Map.t; + regions_hierarchies : T.region_groups FunIdMap.t; } [@@deriving show] diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index d5eac6f4..4ce1e9f1 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -1188,14 +1188,18 @@ let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl) let def_id = def.def_id in let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in let sg = llbc_def.signature in - let num_rgs = List.length sg.regions_hierarchy in + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) + ctx.trans_ctx.fun_ctx.regions_hierarchies + in + let num_rgs = List.length regions_hierarchy in let { keep_fwd; fwd = _; backs } = trans_group in let num_backs = List.length backs in let rg_info = match def.back_id with | None -> None | Some rg_id -> - let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in + let rg = T.RegionGroupId.nth regions_hierarchy rg_id in let region_names = List.map (fun rid -> (T.RegionId.nth sg.generics.regions rid).name) diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index 395c0c80..69c9af62 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -29,7 +29,10 @@ let compute_contexts (m : A.crate) : C.decls_ctx = let fun_infos = FunsAnalysis.analyze_module m fun_decls global_decls !Config.use_state in - let fun_ctx = { C.fun_decls; fun_infos } in + let regions_hierarchies = + RegionsHierarchy.compute_regions_hierarchies type_decls fun_decls + in + let fun_ctx = { C.fun_decls; fun_infos; regions_hierarchies } in let global_ctx = { C.global_decls } in let trait_decls_ctx = { C.trait_decls } in let trait_impls_ctx = { C.trait_impls } in @@ -124,8 +127,8 @@ let symbolic_instantiate_fun_sig (ctx : C.eval_ctx) (sg : A.fun_sig) List.fold_left_map (fun tr_map (c : T.trait_clause) -> let subst = mk_subst tr_map in - let { T.trait_id = trait_decl_id; generics; _ } = c in - let generics = Subst.generic_args_substitute subst generics in + let { T.trait_id = trait_decl_id; clause_generics; _ } = c in + let generics = Subst.generic_args_substitute subst clause_generics in let trait_decl_ref = { T.trait_decl_id; decl_generics = generics } in (* Note that because we directly refer to the clause, we give it empty generics *) @@ -183,8 +186,11 @@ let initialize_symbolic_context_for_fun (ctx : C.decls_ctx) (fdef : A.fun_decl) * *) let sg = fdef.signature in (* Create the context *) + let regions_hierarchy = + FunIdMap.find (FRegular fdef.def_id) ctx.fun_ctx.regions_hierarchies + in let region_groups = - List.map (fun (g : T.region_group) -> g.id) sg.regions_hierarchy + List.map (fun (g : T.region_group) -> g.id) regions_hierarchy in let ctx = initialize_eval_context ctx region_groups sg.generics.types @@ -269,7 +275,6 @@ let evaluate_function_symbolic_synthesize_backward_from_return * the return type. Note that it is important to re-generate * an instantiation of the signature, so that we use fresh * region ids for the return abstractions. *) - let sg = fdef.signature in let _, ret_inst_sg = symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind in @@ -282,7 +287,10 @@ let evaluate_function_symbolic_synthesize_backward_from_return * will end - this will allow us to, first, mark the other return * regions as non-endable, and, second, end those parent regions in * proper order. *) - let parent_rgs = list_ancestor_region_groups sg back_id in + let regions_hierarchy = + FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies + in + let parent_rgs = list_ancestor_region_groups regions_hierarchy back_id in let parent_input_abs_ids = T.RegionGroupId.mapi (fun rg_id rg -> @@ -455,6 +463,10 @@ let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx) (* Create the evaluation context *) let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in + let regions_hierarchy = + FunIdMap.find (FRegular fdef.def_id) ctx.fun_context.regions_hierarchies + in + (* Create the continuation to finish the evaluation *) let config = C.mk_config C.SymbolicMode in let cf_finish res ctx = @@ -511,7 +523,7 @@ let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx) let back_el = T.RegionGroupId.mapi (fun gid _ -> (gid, finish_back_eval gid)) - fdef.signature.regions_hierarchy + regions_hierarchy in let back_el = T.RegionGroupId.Map.of_list back_el in (* Put everything together *) @@ -555,7 +567,7 @@ let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx) let back_el = T.RegionGroupId.mapi (fun gid _ -> (gid, finish_back_eval gid)) - fdef.signature.regions_hierarchy + regions_hierarchy in let back_el = T.RegionGroupId.Map.of_list back_el in (* Put everything together *) diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index cbc09c29..3627490d 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -1434,7 +1434,7 @@ and eval_assumed_function_call_symbolic (config : C.config) let inst_sig = match fid with | BoxFree -> - (* should have been treated above *) + (* Should have been treated above *) raise (Failure "Unreachable") | _ -> (* There shouldn't be any reference to Self *) diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml index 60747b2a..5e2746c5 100644 --- a/compiler/InterpreterUtils.ml +++ b/compiler/InterpreterUtils.ml @@ -474,6 +474,11 @@ let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.generic_args) (* Erase the regions in the generics we use for the instantiation *) let generics = Subst.generic_args_erase_regions generics in let tr_self = Subst.trait_instance_id_erase_regions tr_self in + (* Compute the regions hierarchy *) + let regions_hierarchy = + RegionsHierarchy.compute_regions_hierarchy_for_sig + ctx.type_context.type_decls sg + in (* Generate fresh abstraction ids and create a substitution from region * group ids to abstraction ids *) let rg_abs_ids_bindings = @@ -481,7 +486,7 @@ let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.generic_args) (fun rg -> let abs_id = C.fresh_abstraction_id () in (rg.T.id, abs_id)) - sg.regions_hierarchy + regions_hierarchy in let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t = List.fold_left @@ -512,7 +517,7 @@ let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.generic_args) (* Substitute the signature *) let inst_sig = AssociatedTypes.ctx_subst_norm_signature ctx asubst rsubst tsubst cgsubst - tr_subst tr_self sg + tr_subst tr_self sg regions_hierarchy in (* Return *) inst_sig diff --git a/compiler/LlbcAstUtils.ml b/compiler/LlbcAstUtils.ml index 46b36851..de46320b 100644 --- a/compiler/LlbcAstUtils.ml +++ b/compiler/LlbcAstUtils.ml @@ -1,5 +1,18 @@ open LlbcAst include Charon.LlbcAstUtils +open Collections + +module FunIdOrderedType : OrderedType with type t = fun_id = struct + type t = fun_id + + let compare = compare_fun_id + let to_string = show_fun_id + let pp_t = pp_fun_id + let show_t = show_fun_id +end + +module FunIdMap = Collections.MakeMap (FunIdOrderedType) +module FunIdSet = Collections.MakeSet (FunIdOrderedType) let lookup_fun_sig (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) : fun_sig = diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 8872571f..d2747a4b 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -786,18 +786,19 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) if rg_id0 = rg_id1 then true else (* We need to use the regions hierarchy *) - (* First, lookup the signature of the LLBC function *) - let sg = + let regions_hierarchy = let id0 = match id0 with | FunId fun_id -> fun_id | TraitMethod (_, _, fun_decl_id) -> FRegular fun_decl_id in - LlbcAstUtils.lookup_fun_sig id0 ctx.fun_ctx.fun_decls + LlbcAstUtils.FunIdMap.find id0 + ctx.fun_ctx.regions_hierarchies in (* Compute the set of ancestors of the function in call1 *) let call1_ancestors = - LlbcAstUtils.list_ancestor_region_groups sg rg_id1 + LlbcAstUtils.list_ancestor_region_groups regions_hierarchy + rg_id1 in (* Check if the function used in call0 is inside *) T.RegionGroupId.Set.mem rg_id0 call1_ancestors diff --git a/compiler/RegionsHierarchy.ml b/compiler/RegionsHierarchy.ml new file mode 100644 index 00000000..dd566426 --- /dev/null +++ b/compiler/RegionsHierarchy.ml @@ -0,0 +1,243 @@ +(** This module analyzes function signatures to compute the + hierarchy between regions. + + Note that we don't need to analyze the types: when there is a non-trivial + relation between lifetimes in a type definition, the Rust compiler will + automatically introduce the relevant where clauses. For instance, in the + definition below: + + {[ + struct Wrapper<'a, 'b, T> { + x : &'a mut &'b mut T, + } + ]} + + the Rust compiler will introduce the where clauses: + {[ + 'b : 'a + T : 'b + ]} + + However, it doesn't do so for the function signatures, which means we have + to compute the constraints between the lifetimes ourselves, then that we + have to compute the SCCs of the lifetimes (two lifetimes 'a and 'b may + satisfy 'a : 'b and 'b : 'a, meaning they are actually equal and should + be grouped together). + *) + +open Types +open TypesUtils +open Expressions +open LlbcAst +open LlbcAstUtils +open Assumed +open SCC +module Subst = Substitute + +let compute_regions_hierarchy_for_sig (type_decls : type_decl TypeDeclId.Map.t) + (sg : fun_sig) : region_groups = + (* Create the dependency graph. + + An edge from 'short to 'long means that 'long outlives 'short (that is + we have 'long : 'short, using Rust notations). + *) + (* First initialize the regions map. + + We add: + - the region variables + - the static region + - edges from the region variables to the static region + *) + let g : RegionSet.t RegionMap.t ref = + let s_set = RegionSet.singleton RStatic in + let m = + List.map + (fun (r : region_var) -> (RVar r.index, s_set)) + sg.generics.regions + in + let s = (RStatic, RegionSet.empty) in + ref (RegionMap.of_list (s :: m)) + in + + let add_edge ~(short : region) ~(long : region) = + let m = !g in + let s = RegionMap.find short !g in + let s = RegionSet.add long s in + g := RegionMap.add short s m + in + + let add_edge_from_region_constraint ((long, short) : region_outlives) = + add_edge ~short ~long + in + + let add_edges ~(long : region) ~(shorts : region list) = + List.iter (fun short -> add_edge ~short ~long) shorts + in + + (* Explore the clauses - we only explore the "region outlives" clause, + not the "type outlives" clauses *) + List.iter add_edge_from_region_constraint sg.preds.regions_outlive; + + (* Explore the types in the signature to add the edges *) + let rec explore_ty (outer : region list) (ty : ty) = + match ty with + | TAdt (id, generics) -> + (* Add constraints coming from the type clauses *) + (match id with + | TAdtId id -> + (* Lookup the type declaration *) + let decl = TypeDeclId.Map.find id type_decls in + (* Instantiate the predicates *) + let tr_self = + UnknownTrait ("Unexpected, introduced by " ^ __FUNCTION__) + in + let subst = + Subst.make_subst_from_generics decl.generics generics tr_self + in + let predicates = Subst.predicates_substitute subst decl.preds in + (* Note that because we also explore the generics below, we may + explore several times the same type - this is ok *) + List.iter + (fun (long, short) -> add_edges ~long ~shorts:(short :: outer)) + predicates.regions_outlive; + List.iter + (fun (ty, short) -> explore_ty (short :: outer) ty) + predicates.types_outlive + | TTuple -> (* No clauses for tuples *) () + | TAssumed aid -> ( + match aid with + | TBox | TArray | TSlice | TStr -> (* No clauses for those *) ())); + (* Explore the generics *) + explore_generics outer generics + | TVar _ | TLiteral _ | TNever -> () + | TRef (r, ty, _) -> + (* Add the constraints for r *) + add_edges ~long:r ~shorts:outer; + (* Add r to the outer regions *) + let outer = r :: outer in + (* Continue *) + explore_ty outer ty + | TRawPtr (ty, _) -> explore_ty outer ty + | TTraitType (trait_ref, _generic_args, _) -> + (* The trait should reference a clause, and not an implementation + (otherwise it should have been normalized) *) + assert ( + AssociatedTypes.trait_instance_id_is_local_clause trait_ref.trait_id); + (* We have nothing to do *) + () + | TArrow (inputs, output) -> + (* TODO: this may be too constraining *) + List.iter (explore_ty outer) (output :: inputs) + and explore_generics (outer : region list) (generics : generic_args) = + let { regions; types; const_generics = _; trait_refs = _ } = generics in + List.iter (fun long -> add_edges ~long ~shorts:outer) regions; + List.iter (explore_ty outer) types + in + + List.iter (explore_ty []) (sg.output :: sg.inputs); + + (* Compute the ordered SCCs *) + let module Scc = SCC.Make (RegionOrderedType) in + let sccs = Scc.compute (RegionMap.bindings !g) in + + (* Remove the SCC containing the static region. + + For now, we don't handle cases where regions different from 'static + can live as long as 'static, so we check that if the group contains + 'static then it is the only region it contains, and then we filter + the group. + TODO: general support for 'static + *) + let sccs = + (* Find the SCC which contains the static region *) + let static_gr_id, static_scc = + List.find + (fun (_, scc) -> List.mem RStatic scc) + (SccId.Map.bindings sccs.sccs) + in + (* The SCC should only contain the 'static *) + assert (static_scc = [ RStatic ]); + (* Remove the group as well as references to this group from the + other SCCs *) + let { sccs; scc_deps } = sccs in + (* We have to change the indexing: + - if id < static_gr_id: we leave the id as it is + - if id = static_gr_id: we remove id + - if id > static_gr_id: we decrement it by one + *) + let static_i = SccId.to_int static_gr_id in + let convert_id (id : SccId.id) : SccId.id option = + let i = SccId.to_int id in + if i < static_i then Some id + else if i = static_i then None + else Some (SccId.of_int (i - 1)) + in + let sccs = + SccId.Map.of_list + (List.filter_map + (fun (id, rg_ids) -> + match convert_id id with + | None -> None + | Some id -> Some (id, rg_ids)) + (SccId.Map.bindings sccs)) + in + + let scc_deps = + List.filter_map + (fun (id, deps) -> + match convert_id id with + | None -> None + | Some id -> + let deps = List.filter_map convert_id (SccId.Set.elements deps) in + Some (id, SccId.Set.of_list deps)) + (SccId.Map.bindings scc_deps) + in + let scc_deps = SccId.Map.of_list scc_deps in + + { sccs; scc_deps } + in + + (* + * Compute the regions hierarchy + *) + List.filter_map + (fun (scc_id, scc) -> + (* The region id *) + let i = SccId.to_int scc_id in + let id = RegionGroupId.of_int i in + + (* Retrieve the set of regions in the group *) + let regions = + List.map + (fun r -> + match r with RVar r -> r | _ -> raise (Failure "Unreachable")) + scc + in + + (* Compute the set of parent region groups *) + let parents = + List.map + (fun id -> RegionGroupId.of_int (SccId.to_int id)) + (SccId.Set.elements (SccId.Map.find scc_id sccs.scc_deps)) + in + + (* Put together *) + Some { id; regions; parents }) + (SccId.Map.bindings sccs.sccs) + +let compute_regions_hierarchies (type_decls : type_decl TypeDeclId.Map.t) + (fun_decls : fun_decl FunDeclId.Map.t) : region_groups FunIdMap.t = + let regular = + List.map + (fun (fid, d) -> (FRegular fid, d.signature)) + (FunDeclId.Map.bindings fun_decls) + in + let assumed = + List.map + (fun (info : assumed_fun_info) -> (FAssumed info.fun_id, info.fun_sig)) + assumed_fun_infos + in + FunIdMap.of_list + (List.map + (fun (fid, sg) -> (fid, compute_regions_hierarchy_for_sig type_decls sg)) + (regular @ assumed)) diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml index c82d625f..53c94ff4 100644 --- a/compiler/ReorderDecls.ml +++ b/compiler/ReorderDecls.ml @@ -1,4 +1,3 @@ -open Graph open Collections open SCC open Pure @@ -99,99 +98,9 @@ let group_reorder_fun_decls (decls : fun_decl list) : decls in - (* - * Create the dependency graph - *) - (* Convert the ids to vertices (i.e., injectively map ids to integers, and create - vertices labeled with those integers). - - Rem.: [Graph.create] is *imperative*: it generates a new vertex every time - it is called (!!). - *) - let module Graph = Pack.Digraph in - let id_to_vertex : Graph.V.t FunIdMap.t = - let cnt = ref 0 in - FunIdMap.of_list - (List.map - (fun id -> - let lbl = !cnt in - cnt := !cnt + 1; - (* We create a vertex *) - let v = Graph.V.create lbl in - (id, v)) - idl) - in - let vertex_to_id : fun_id IntMap.t = - IntMap.of_list - (List.map - (fun (fid, v) -> (Graph.V.label v, fid)) - (FunIdMap.bindings id_to_vertex)) - in - - let to_v id = FunIdMap.find id id_to_vertex in - let to_id v = IntMap.find (Graph.V.label v) vertex_to_id in - - let g = Graph.create () in - - (* Add the edges, first from the vertices to themselves, then between vertices *) - List.iter - (fun (fun_id, deps) -> - let v = to_v fun_id in - Graph.add_edge g v v; - FunIdSet.iter (fun dep_id -> Graph.add_edge g v (to_v dep_id)) deps) - deps; - - (* Compute the SCCs *) - let module Comp = Components.Make (Graph) in - let sccs = Comp.scc_list g in - - (* Convert the vertices to ids *) - let sccs = List.map (List.map to_id) sccs in - - log#ldebug - (lazy - ("group_reorder_fun_decls: SCCs:\n" - ^ Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs - )); - - (* Sanity check *) - let _ = - (* Check that the SCCs are pairwise disjoint *) - assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs)); - (* Check that all the ids are in the sccs *) - let scc_ids = FunIdSet.of_list (List.concat sccs) in - - log#ldebug - (lazy - ("group_reorder_fun_decls: sanity check:" ^ "\n- ids : " - ^ FunIdSet.show ids ^ "\n- scc_ids: " ^ FunIdSet.show scc_ids)); - - assert (FunIdSet.equal scc_ids ids) - in - - log#ldebug - (lazy - ("group_reorder_fun_decls: reordered SCCs:\n" - ^ Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs - )); - - (* Reorder *) - let module Reorder = SCC.Make (FunIdOrderedType) in - let id_deps = - FunIdMap.of_list - (List.map (fun (fid, deps) -> (fid, FunIdSet.elements deps)) deps) - in - let sccs = Reorder.reorder_sccs id_deps idl sccs in - - (* Sanity check *) - let _ = - (* Check that the SCCs are pairwise disjoint *) - let sccs = List.map snd (SccId.Map.bindings sccs.sccs) in - assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs)); - (* Check that all the ids are in the sccs *) - let scc_ids = FunIdSet.of_list (List.concat sccs) in - assert (FunIdSet.equal scc_ids ids) - in + (* Compute the ordered SCCs *) + let module Scc = SCC.Make (FunIdOrderedType) in + let sccs = Scc.compute deps in (* Group the declarations *) let deps = FunIdMap.of_list deps in diff --git a/compiler/SCC.ml b/compiler/SCC.ml index d9a4cd3e..150821ad 100644 --- a/compiler/SCC.ml +++ b/compiler/SCC.ml @@ -6,8 +6,15 @@ module SccId = Identifiers.IdGen () (** The local logger *) let log = Logging.scc_log +(** A structure containing information about SCCs (strongly connected components) *) +type 'id sccs = { + sccs : 'id list SccId.Map.t; + scc_deps : SccId.Set.t SccId.Map.t; (** The dependencies between sccs *) +} +[@@deriving show] + (** A functor which provides functions to work on strongly connected components *) -module Make (Id : OrderedType) = struct +module MakeReorder (Id : OrderedType) = struct module IdMap = MakeMap (Id) module IdSet = MakeSet (Id) @@ -15,13 +22,6 @@ module Make (Id : OrderedType) = struct let pp_id = Id.pp_t - (** A structure containing information about SCCs (strongly connected components) *) - type sccs = { - sccs : id list SccId.Map.t; - scc_deps : SccId.Set.t SccId.Map.t; (** The dependencies between sccs *) - } - [@@deriving show] - (** The order in which Tarjan's algorithm generates the SCCs is arbitrary, while we want to keep as much as possible the original order (the order in which the user generated the ids). For this, we iterate through @@ -93,7 +93,7 @@ module Make (Id : OrderedType) = struct Charon project. *) let reorder_sccs (id_deps : Id.t list IdMap.t) (ids : Id.t list) - (sccs : Id.t list list) : sccs = + (sccs : Id.t list list) : id sccs = (* Map the identifiers to the SCC indices *) let id_to_scc = IdMap.of_list @@ -168,13 +168,114 @@ module Make (Id : OrderedType) = struct { sccs = tgt_sccs; scc_deps = tgt_deps } end +module Make (Id : OrderedType) = struct + module M = MakeMap (Id) + module S = MakeSet (Id) + + (** Compute the ordered SCC components for a graph, which is a map + from identifier to set of identifiers (which represent the set + of edges starting from an identifier). + *) + let compute (m : (Id.t * S.t) list) : Id.t sccs = + (* + * Create the dependency graph + *) + (* Compute the list/set of identifiers *) + let idl = List.map fst m in + let ids = S.of_list idl in + + (* Convert the ids to vertices (i.e., injectively map ids to integers, + and create vertices labeled with those integers). + + Rem.: [Graph.create] is *imperative*: it generates a new vertex every + time it is called (!!). For this reason, we first add all the vertices + we need, then add the edges. + *) + let open Graph in + let module IntMap = MakeMap (OrderedInt) in + let module Graph = Pack.Digraph in + let id_to_vertex : Graph.V.t M.t = + let cnt = ref 0 in + M.of_list + (List.map + (fun id -> + let lbl = !cnt in + cnt := !cnt + 1; + (* We create a vertex *) + let v = Graph.V.create lbl in + (id, v)) + idl) + in + let vertex_to_id : Id.t IntMap.t = + IntMap.of_list + (List.map + (fun (fid, v) -> (Graph.V.label v, fid)) + (M.bindings id_to_vertex)) + in + + let to_v id = M.find id id_to_vertex in + let to_id v = IntMap.find (Graph.V.label v) vertex_to_id in + + let g = Graph.create () in + + (* Add the edges, first from the vertices to themselves, then between + vertices. *) + List.iter + (fun (id, deps) -> + let v = to_v id in + Graph.add_edge g v v; + S.iter (fun dep_id -> Graph.add_edge g v (to_v dep_id)) deps) + m; + + (* Compute the SCCs *) + let module Comp = Components.Make (Graph) in + let sccs = Comp.scc_list g in + + (* Convert the vertices to ids *) + let sccs = List.map (List.map to_id) sccs in + + (* Sanity check *) + let _ = + (* Check that the SCCs are pairwise disjoint *) + assert (S.pairwise_disjoint (List.map S.of_list sccs)); + (* Check that all the ids are in the sccs *) + let scc_ids = S.of_list (List.concat sccs) in + + log#ldebug + (lazy + ("group_reorder_fun_decls: sanity check:" ^ "\n- ids : " + ^ S.show ids ^ "\n- scc_ids: " ^ S.show scc_ids)); + + assert (S.equal scc_ids ids) + in + + (* Reorder *) + let module Reorder = MakeReorder (Id) in + let id_deps = + M.of_list (List.map (fun (fid, deps) -> (fid, S.elements deps)) m) + in + let sccs = Reorder.reorder_sccs id_deps idl sccs in + + (* Sanity check *) + let _ = + (* Check that the SCCs are pairwise disjoint *) + let sccs = List.map snd (SccId.Map.bindings sccs.sccs) in + assert (S.pairwise_disjoint (List.map S.of_list sccs)); + (* Check that all the ids are in the sccs *) + let scc_ids = S.of_list (List.concat sccs) in + assert (S.equal scc_ids ids) + in + + sccs +end + (** Test - TODO: make "real" unit tests *) let _ = (* Check that some SCCs are correctly reordered *) let check_sccs (id_deps : (int * int list) list) (ids : int list) (sccs : int list list) (tgt_sccs : int list list) : unit = let module Ord = OrderedInt in - let module Reorder = Make (Ord) in + let module Reorder = MakeReorder (Ord) in let module Map = MakeMap (Ord) in let id_deps = Map.of_list id_deps in diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml index 166c237a..45edc602 100644 --- a/compiler/Substitute.ml +++ b/compiler/Substitute.ml @@ -64,6 +64,10 @@ let generic_args_substitute (subst : subst) (g : T.generic_args) : let visitor = st_substitute_visitor subst in visitor#visit_generic_args () g +let predicates_substitute (subst : subst) (p : T.predicates) : T.predicates = + let visitor = st_substitute_visitor subst in + visitor#visit_predicates () p + let erase_regions_subst : subst = { r_subst = (fun _ -> T.RErased); @@ -351,7 +355,8 @@ let trait_type_constraint_substitute (subst : subst) let ty = visitor#visit_ty () ty in { T.trait_ref; generics; type_name; ty } -(** Substitute a function signature. +(** Substitute a function signature, together with the regions hierarchy + associated to that signature. **IMPORTANT:** this function doesn't normalize the types. *) @@ -360,7 +365,8 @@ let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) (ty_subst : T.TypeVarId.id -> T.ty) (cg_subst : T.ConstGenericVarId.id -> T.const_generic) (tr_subst : T.TraitClauseId.id -> T.trait_instance_id) - (tr_self : T.trait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig = + (tr_self : T.trait_instance_id) (sg : A.fun_sig) + (regions_hierarchy : T.region_groups) : A.inst_fun_sig = let r_subst' (r : T.region) : T.region = match r with | T.RStatic | T.RErased -> r @@ -375,7 +381,7 @@ let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) let parents = List.map asubst rg.parents in ({ id; regions; parents } : A.abs_region_group) in - let regions_hierarchy = List.map subst_region_group sg.A.regions_hierarchy in + let regions_hierarchy = List.map subst_region_group regions_hierarchy in let trait_type_constraints = List.map (trait_type_constraint_substitute subst) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 258b1cf2..922aa307 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -47,6 +47,7 @@ type fun_context = { llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; fun_sigs : fun_sig_named_outputs RegularFunIdNotLoopMap.t; (** *) fun_infos : FA.fun_info A.FunDeclId.Map.t; + regions_hierarchies : T.region_groups FunIdMap.t; } [@@deriving show] @@ -441,8 +442,8 @@ and translate_strait_instance_id (id : T.trait_instance_id) : trait_instance_id translate_trait_instance_id translate_sty id let translate_trait_clause (clause : T.trait_clause) : trait_clause = - let { T.clause_id; meta = _; trait_id; generics } = clause in - let generics = translate_sgeneric_args generics in + let { T.clause_id; meta = _; trait_id; clause_generics } = clause in + let generics = translate_sgeneric_args clause_generics in { clause_id; trait_id; generics } let translate_strait_type_constraint (ttc : T.trait_type_constraint) : @@ -848,11 +849,14 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) + let regions_hierarchy = + FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies + in let gid, parents = match bid with | None -> (None, T.RegionGroupId.Set.empty) | Some bid -> - let parents = list_ancestor_region_groups sg bid in + let parents = list_ancestor_region_groups regions_hierarchy bid in (Some bid, parents) in (* Is the function stateful, and can it fail? *) @@ -865,7 +869,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (* Create the context *) let ctx = let region_groups = - List.map (fun (g : T.region_group) -> g.id) sg.regions_hierarchy + List.map (fun (g : T.region_group) -> g.id) regions_hierarchy in let ctx = InterpreterUtils.initialize_eval_context decls_ctx region_groups @@ -898,7 +902,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) assert (T.RegionGroupId.Set.is_empty parents); (* Small helper to translate types for backward functions *) let translate_back_ty_for_gid (gid : T.RegionGroupId.id) : T.ty -> ty option = - let rg = T.RegionGroupId.nth sg.regions_hierarchy gid in + let rg = T.RegionGroupId.nth regions_hierarchy gid in let regions = T.RegionId.Set.of_list rg.regions in let keep_region r = match r with @@ -2924,6 +2928,9 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let basename = def.name in (* Retrieve the signature *) let signature = ctx.sg in + let regions_hierarchy = + FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies + in (* Translate the body, if there is *) let body = match body with @@ -2965,7 +2972,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> [] | Some back_id -> let parents_ids = - list_ordered_ancestor_region_groups def.signature back_id + list_ordered_ancestor_region_groups regions_hierarchy back_id in let backward_ids = List.append parents_ids [ back_id ] in List.concat @@ -3069,6 +3076,10 @@ let translate_fun_signatures (decls_ctx : C.decls_ctx) let translate_one (fun_id : A.fun_id) (input_names : string option list) (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list = + (* Retrieve the regions hierarchy *) + let regions_hierarchy = + FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies + in (* The forward function *) let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in let fwd_id = (fun_id, None) in @@ -3081,7 +3092,7 @@ let translate_fun_signatures (decls_ctx : C.decls_ctx) in let id = (fun_id, Some rg.id) in (id, tsg)) - sg.regions_hierarchy + regions_hierarchy in (* Return *) (fwd_id, fwd_sg) :: back_sgs diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 9a6addee..2aedb544 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -91,6 +91,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls; fun_sigs; fun_infos = trans_ctx.fun_ctx.fun_infos; + regions_hierarchies = trans_ctx.fun_ctx.regions_hierarchies; } in let global_context = @@ -263,9 +264,11 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Translate *) SymbolicToPure.translate_fun_decl ctx (Some symbolic) in - let pure_backwards = - List.map translate_backward fdef.signature.regions_hierarchy + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular fdef.def_id) + fun_context.regions_hierarchies in + let pure_backwards = List.map translate_backward regions_hierarchy in (* Return *) (pure_forward, pure_backwards) diff --git a/compiler/dune b/compiler/dune index 648c7325..bc3cc718 100644 --- a/compiler/dune +++ b/compiler/dune @@ -57,6 +57,7 @@ Pure PureTypeCheck PureUtils + RegionsHierarchy ReorderDecls SCC Scalars -- cgit v1.2.3