diff options
-rw-r--r-- | compiler/Contexts.ml | 1 | ||||
-rw-r--r-- | compiler/Interpreter.ml | 11 | ||||
-rw-r--r-- | compiler/InterpreterLoops.ml | 283 | ||||
-rw-r--r-- | compiler/PrePasses.ml | 116 | ||||
-rw-r--r-- | compiler/Values.ml | 30 |
5 files changed, 387 insertions, 54 deletions
diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index 69c4ec3b..55baa6a4 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -259,6 +259,7 @@ type eval_ctx = { type_context : type_context; fun_context : fun_context; global_context : global_context; + region_groups : RegionGroupId.id list; type_vars : type_var list; env : env; ended_regions : RegionId.Set.t; diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index ec1b6260..4b030088 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -29,12 +29,14 @@ let compute_type_fun_global_contexts (m : A.crate) : let initialize_eval_context (type_context : C.type_context) (fun_context : C.fun_context) (global_context : C.global_context) - (type_vars : T.type_var list) : C.eval_ctx = + (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list) : + C.eval_ctx = C.reset_global_counters (); { C.type_context; C.fun_context; C.global_context; + C.region_groups; C.type_vars; C.env = [ C.Frame ]; C.ended_regions = T.RegionId.Set.empty; @@ -69,9 +71,12 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context) * *) let sg = fdef.signature in (* Create the context *) + let region_groups = + List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy + in let ctx = initialize_eval_context type_context fun_context global_context - sg.type_params + region_groups sg.type_params in (* Instantiate the signature *) let type_params = List.map (fun tv -> T.TypeVar tv.T.index) sg.type_params in @@ -312,7 +317,7 @@ module Test = struct compute_type_fun_global_contexts crate in let ctx = - initialize_eval_context type_context fun_context global_context [] + initialize_eval_context type_context fun_context global_context [] [] in (* Insert the (uninitialized) local variables *) diff --git a/compiler/InterpreterLoops.ml b/compiler/InterpreterLoops.ml index 53005bf3..e68790d4 100644 --- a/compiler/InterpreterLoops.ml +++ b/compiler/InterpreterLoops.ml @@ -245,7 +245,7 @@ let compute_abs_borrows_loans_maps (no_duplicates : bool) (** Destructure all the new abstractions *) let destructure_new_abs (loop_id : V.LoopId.id) (old_abs_ids : V.AbstractionId.Set.t) (ctx : C.eval_ctx) : C.eval_ctx = - let abs_kind = V.Loop loop_id in + let abs_kind = V.Loop (loop_id, None) in let can_end = false in let destructure_shared_values = true in let is_fresh_abs_id (id : V.AbstractionId.id) : bool = @@ -336,7 +336,7 @@ let collapse_ctx (loop_id : V.LoopId.id) ("collapse_ctx:\n\n- fixed_ids:\n" ^ show_ids_sets old_ids ^ "\n\n- ctx0:\n" ^ eval_ctx_to_string ctx0 ^ "\n\n")); - let abs_kind = V.Loop loop_id in + let abs_kind = V.Loop (loop_id, None) in let can_end = false in let destructure_shared_values = true in let is_fresh_abs_id (id : V.AbstractionId.id) : bool = @@ -966,7 +966,7 @@ module MakeJoinMatcher (S : MatchJoinState) : Matcher = struct let abs = { V.abs_id = C.fresh_abstraction_id (); - kind = V.Loop S.loop_id; + kind = V.Loop (S.loop_id, None); can_end = false; parents = V.AbstractionId.Set.empty; original_parents = []; @@ -1020,7 +1020,7 @@ module MakeJoinMatcher (S : MatchJoinState) : Matcher = struct let abs = { V.abs_id = C.fresh_abstraction_id (); - kind = V.Loop S.loop_id; + kind = V.Loop (S.loop_id, None); can_end = false; parents = V.AbstractionId.Set.empty; original_parents = []; @@ -1124,7 +1124,7 @@ module MakeJoinMatcher (S : MatchJoinState) : Matcher = struct else raise (ValueMatchFailure (LoanInRight id))) | None -> (* Convert the value to an abstraction *) - let abs_kind = V.Loop S.loop_id in + let abs_kind = V.Loop (S.loop_id, None) in let can_end = false in let destructure_shared_values = true in let absl = @@ -1146,15 +1146,8 @@ module MakeJoinMatcher (S : MatchJoinState) : Matcher = struct let match_avalues _ _ = raise (Failure "Unreachable") end -(** Collapse an environment, merging the duplicated borrows/loans. - - This function simply calls {!collapse_ctx} with the proper merging functions. - - We do this because when we join environments, we may introduce duplicated - loans and borrows. See the explanations for {!join_ctxs}. - *) -let collapse_ctx_with_merge (loop_id : V.LoopId.id) (old_ids : ids_sets) - (ctx : C.eval_ctx) : C.eval_ctx = +let mk_collapse_ctx_merge_duplicate_funs (loop_id : V.LoopId.id) + (ctx : C.eval_ctx) : merge_duplicates_funcs = (* Rem.: the merge functions raise exceptions (that we catch). *) let module S : MatchJoinState = struct let ctx = ctx @@ -1228,15 +1221,30 @@ let collapse_ctx_with_merge (loop_id : V.LoopId.id) (old_ids : ids_sets) let value = V.ALoan (V.ASharedLoan (ids, sv, child)) in { V.value; ty } in - let merge_funcs = - { - merge_amut_borrows; - merge_ashared_borrows; - merge_amut_loans; - merge_ashared_loans; - } - in - try collapse_ctx loop_id (Some merge_funcs) old_ids ctx + { + merge_amut_borrows; + merge_ashared_borrows; + merge_amut_loans; + merge_ashared_loans; + } + +let merge_abstractions (loop_id : V.LoopId.id) (abs_kind : V.abs_kind) + (can_end : bool) (ctx : C.eval_ctx) (aid0 : V.AbstractionId.id) + (aid1 : V.AbstractionId.id) : C.eval_ctx * V.AbstractionId.id = + let merge_funs = mk_collapse_ctx_merge_duplicate_funs loop_id ctx in + merge_abstractions abs_kind can_end (Some merge_funs) ctx aid0 aid1 + +(** Collapse an environment, merging the duplicated borrows/loans. + + This function simply calls {!collapse_ctx} with the proper merging functions. + + We do this because when we join environments, we may introduce duplicated + loans and borrows. See the explanations for {!join_ctxs}. + *) +let collapse_ctx_with_merge (loop_id : V.LoopId.id) (old_ids : ids_sets) + (ctx : C.eval_ctx) : C.eval_ctx = + let merge_funs = mk_collapse_ctx_merge_duplicate_funs loop_id ctx in + try collapse_ctx loop_id (Some merge_funs) old_ids ctx with ValueMatchFailure _ -> raise (Failure "Unexpected") (** Join two contexts. @@ -1450,6 +1458,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) C.type_context; fun_context; global_context; + region_groups; type_vars; env = _; ended_regions = ended_regions0; @@ -1460,6 +1469,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) C.type_context = _; fun_context = _; global_context = _; + region_groups = _; type_vars = _; env = _; ended_regions = ended_regions1; @@ -1472,6 +1482,7 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) C.type_context; fun_context; global_context; + region_groups; type_vars; env; ended_regions; @@ -1830,6 +1841,10 @@ let match_ctxs (check_equiv : bool) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) (* Rem.: this function raises exceptions of type [Distinct] *) let match_abstractions (abs0 : V.abs) (abs1 : V.abs) : unit = + log#ldebug + (lazy + ("match_ctxs: match_abstractions: " ^ "\n\n- abs0: " ^ V.show_abs abs0 + ^ "\n\n- abs0: " ^ V.show_abs abs1)); let { V.abs_id = abs_id0; kind = kind0; @@ -1917,8 +1932,11 @@ let match_ctxs (check_equiv : bool) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) (* Continue *) match_envs env0' env1' | C.Abs abs0 :: env0', C.Abs abs1 :: env1' -> + log#ldebug (lazy "ctxs_are_equivalent: match_envs: matching abs"); (* Same as for the dummy values: there are two cases *) if V.AbstractionId.Set.mem abs0.abs_id fixed_ids.aids then ( + log#ldebug + (lazy "ctxs_are_equivalent: match_envs: matching abs: fixed abs"); (* Still in the prefix: the abstractions must be the same *) assert (abs0 = abs1); (* Their ids must be fixed *) @@ -1927,6 +1945,9 @@ let match_ctxs (check_equiv : bool) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) (* Continue *) match_envs env0' env1') else ( + log#ldebug + (lazy + "ctxs_are_equivalent: match_envs: matching abs: not fixed abs"); (* Match the values *) match_abstractions abs0 abs1; (* Continue *) @@ -2188,12 +2209,13 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id) ctxs_are_equivalent fixed_ids ctx1 ctx2 in let max_num_iter = Config.loop_fixed_point_max_num_iters in - let rec compute_fixed_point (ctx : C.eval_ctx) (i : int) : C.eval_ctx = + let rec compute_fixed_point (ctx : C.eval_ctx) (i0 : int) (i : int) : + C.eval_ctx = if i = 0 then raise (Failure - ("Could not compute a loop fixed point in " - ^ string_of_int max_num_iter ^ " iterations")) + ("Could not compute a loop fixed point in " ^ string_of_int i0 + ^ " iterations")) else (* Evaluate the loop body to register the different contexts upon reentry *) let _ = eval_loop_body cf_exit_loop_body ctx in @@ -2211,35 +2233,206 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id) ^ "\n\n")); (* Check if we reached a fixed point: if not, iterate *) - if equiv_ctxs ctx ctx1 then ctx1 else compute_fixed_point ctx1 (i - 1) + if equiv_ctxs ctx ctx1 then ctx1 else compute_fixed_point ctx1 i0 (i - 1) in - let fp = compute_fixed_point ctx0 max_num_iter in - let fixed_ids = compute_fixed_ids (Option.get !ctx_upon_entry) fp in - (* For now, all the new abstractions in the fixed-point have the same region - group (of id 0): we want each one of them to have a unique region group - (because we will translate each one of those abstractions to a pair - forward function/backward function). - *) - let region_map = ref T.RegionId.Map.empty in - let get_rid (rid : T.RegionId.id) : T.RegionId.id = - if T.RegionId.Set.mem rid fixed_ids.rids then rid - else - match T.RegionId.Map.find_opt rid !region_map with - | Some rid -> rid + let fp = compute_fixed_point ctx0 max_num_iter max_num_iter in + + (* Make sure we have exactly one loop abstraction per function region (merge + abstractions accordingly) *) + let fp = + (* List the loop abstractions in the fixed-point, and set all the abstractions + as endable *) + let fp_aids, add_aid, _mem_aid = V.AbstractionId.Set.mk_stateful_set () in + + let list_loop_abstractions = + object + inherit [_] C.map_eval_ctx as super + + method! visit_abs env abs = + match abs.kind with + | Loop (loop_id', _) -> + assert (loop_id' = loop_id); + add_aid abs.abs_id; + { abs with can_end = true } + | _ -> super#visit_abs env abs + end + in + let fp = list_loop_abstractions#visit_eval_ctx () fp in + + (* For every input region group: + * - evaluate until we get to a [return] + * - end the input abstraction corresponding to the input region group + * - find which loop abstractions end at that moment + * + * [fp_ended_aids] links region groups to sets of ended abstractions. + *) + let fp_ended_aids = ref T.RegionGroupId.Map.empty in + let add_ended_aids (rg_id : T.RegionGroupId.id) + (aids : V.AbstractionId.Set.t) : unit = + match T.RegionGroupId.Map.find_opt rg_id !fp_ended_aids with | None -> - let nrid = C.fresh_region_id () in - region_map := T.RegionId.Map.add rid nrid !region_map; - nrid + fp_ended_aids := T.RegionGroupId.Map.add rg_id aids !fp_ended_aids + | Some aids' -> + let aids = V.AbstractionId.Set.union aids aids' in + fp_ended_aids := T.RegionGroupId.Map.add rg_id aids !fp_ended_aids + in + let cf_loop : st_m_fun = + fun res ctx -> + match res with + | Continue _ | Panic -> + (* We don't want to generate anything *) + None + | Break _ -> + (* We enforce that we can't get there: see {!PrePasses.remove_loop_breaks} *) + raise (Failure "Unreachable") + | Unit | EndEnterLoop | EndContinue -> + (* For why we can't get [Unit], see the comments inside {!eval_loop_concrete}. + For [EndEnterLoop] and [EndContinue]: we don't support nested loops for now. + *) + raise (Failure "Unreachable") + | Return -> + (* Should we consume the return value and pop the frame? + * If we check in [Interpreter] that the loop abstraction we end is + * indeed the correct one, I think it is sound to under-approximate here + * (and it shouldn't make any difference). + *) + let _ = + List.iter + (fun rg_id -> + (* Lookup the input abstraction - we use the fact that the + abstractions should have been introduced in a specific + order (and we check that it is indeed the case) *) + let abs_id = + V.AbstractionId.of_int (T.RegionGroupId.to_int rg_id) + in + let abs = C.ctx_lookup_abs ctx abs_id in + assert (abs.kind = V.SynthInput rg_id); + (* End this abstraction *) + let ctx = + InterpreterBorrows.end_abstraction_no_synth config abs_id ctx + in + (* Explore the context, and check which abstractions are not there anymore *) + let ids = compute_context_ids ctx in + let ended_ids = V.AbstractionId.Set.diff !fp_aids ids.aids in + add_ended_aids rg_id ended_ids) + ctx.region_groups + in + (* We don't want to generate anything *) + None + in + let _ = eval_loop_body cf_loop fp in + + (* Check that the sets of abstractions we need to end per region group are pairwise + * disjoint *) + let aids_union = ref V.AbstractionId.Set.empty in + let _ = + T.RegionGroupId.Map.iter + (fun _ ids -> + assert (V.AbstractionId.Set.disjoint !aids_union ids); + aids_union := V.AbstractionId.Set.union ids !aids_union) + !fp_ended_aids + in + assert (!aids_union = !fp_aids); + + (* Merge the abstractions which need to be merged *) + let fp = ref fp in + let _ = + T.RegionGroupId.Map.iter + (fun rg_id ids -> + let ids = V.AbstractionId.Set.elements ids in + (* Retrieve the first id of the group *) + match ids with + | [] -> + (* This would be rather unexpected... let's fail for now to see + in which situations this happens *) + raise (Failure "Unexpected") + | id0 :: ids -> + let id0 = ref id0 in + (* Add the proper region group into the abstraction *) + let abs_kind = V.Loop (loop_id, Some rg_id) in + let abs = C.ctx_lookup_abs !fp !id0 in + let abs = { abs with V.kind = abs_kind } in + let fp', _ = C.ctx_subst_abs !fp !id0 abs in + fp := fp'; + (* Merge all the abstractions into this one *) + List.iter + (fun id -> + try + let fp', id0' = + merge_abstractions loop_id abs_kind false !fp !id0 id + in + fp := fp'; + id0 := id0'; + () + with ValueMatchFailure _ -> raise (Failure "Unexpected")) + ids) + !fp_ended_aids + in + + (* Reset the loop abstracions as not endable *) + let list_loop_abstractions (remove_rg_id : bool) = + object + inherit [_] C.map_eval_ctx as super + + method! visit_abs env abs = + match abs.kind with + | Loop (loop_id', _) -> + assert (loop_id' = loop_id); + let kind = + if remove_rg_id then V.Loop (loop_id, None) else abs.kind + in + { abs with can_end = false; kind } + | _ -> super#visit_abs env abs + end + in + let update_kinds_can_end (remove_rg_id : bool) ctx = + (list_loop_abstractions remove_rg_id)#visit_eval_ctx () ctx + in + let fp = update_kinds_can_end false !fp in + + (* Check that we still have a fixed point - we simply call [compute_fixed_point] + while allowing exactly one iteration to see if it fails *) + let _ = compute_fixed_point (update_kinds_can_end true fp) 1 1 in + + (* Return *) + fp in + let fixed_ids = compute_fixed_ids (Option.get !ctx_upon_entry) fp in + + (* Return *) + (fp, fixed_ids) + +(* +(** Introduce region groups in the loop abstractions. + + Initially, the new abstractions in the fixed-point have no region group. + We want each one of them to have a unique region group (because we will + translate each one of those abstractions to a pair forward + function/backward function). + *) +let ctx_add_loop_region_groups (loop_id : V.LoopId.id) (fp : C.eval_ctx) : + C.eval_ctx = + let _, fresh_rid = T.RegionGroupId.fresh_stateful_generator () in + let introduce_fresh_rids = object inherit [_] C.map_eval_ctx - method! visit_region_id _ rid = get_rid rid + + method! visit_abs _ abs = + match abs.kind with + | Loop (loop_id', rg_id) -> + assert (loop_id' = loop_id); + assert (rg_id = None); + let rg_id = Some (fresh_rid ()) in + let kind = V.Loop (loop_id, rg_id) in + { abs with V.kind } + | _ -> abs end in let fp_env = List.rev (introduce_fresh_rids#visit_env () (List.rev fp.env)) in let fp = { fp with env = fp_env } in - (fp, fixed_ids) + fp + *) (** Split an environment between the fixed abstractions, values, etc. and the new abstractions, values, etc. diff --git a/compiler/PrePasses.ml b/compiler/PrePasses.ml index 082a81ba..be154539 100644 --- a/compiler/PrePasses.ml +++ b/compiler/PrePasses.ml @@ -8,6 +8,7 @@ module E = Expressions module C = Contexts module A = LlbcAst module L = Logging +open Utils open LlbcAstUtils let log = L.pre_passes_log @@ -55,6 +56,8 @@ let filter_drop_assigns (f : A.fun_decl) : A.fun_decl = merge branches during the symbolic execution in some quite common cases where doing a merge is actually not necessary and leads to an ugly translation. + TODO: this is useless + For instance, it performs the following transformation: {[ if b { @@ -145,8 +148,119 @@ let remove_useless_cf_merges (crate : A.crate) (f : A.fun_decl) : A.fun_decl = ^ "\n")); f +(** This pass restructures the control-flow by inserting all the statements + which occur after loops *inside* the loops, thus removing the need to + have breaks (we later check that we removed all the breaks). + + This is needed because of the way we perform the symbolic execution + on the loops for now. + + Rem.: we check that there are no nested loops (all the breaks must break + to the first outer loop, and the statements we insert inside the loops + mustn't contain breaks themselves). + + For instance, it performs the following transformation: + {[ + loop { + if b { + ... + continue 0; + } + else { + ... + break 0; + } + }; + x := x + 1; + return; + + ~~> + + loop { + if b { + ... + continue 0; + } + else { + ... + x := x + 1; + return; + } + }; + ]} + *) +let remove_loop_breaks (crate : A.crate) (f : A.fun_decl) : A.fun_decl = + let f0 = f in + + (* Check that a statement doesn't contain loops, breaks or continues *) + let statement_has_no_loop_break_continue (st : A.statement) : bool = + let obj = + object + inherit [_] A.iter_statement + method! visit_Loop _ _ = raise Found + method! visit_Break _ _ = raise Found + method! visit_Continue _ _ = raise Found + end + in + try + obj#visit_statement () st; + true + with Found -> false + in + + (* Replace a break statement with another statement (we check that the + break statement breaks exactly one level, and that there are no nested + loops. + *) + let replace_breaks_with (st : A.statement) (nst : A.statement) : A.statement = + let obj = + object + inherit [_] A.map_statement as super + + method! visit_Loop entered_loop loop = + assert (not entered_loop); + super#visit_Loop true loop + + method! visit_Break _ i = + assert (i = 0); + nst.content + end + in + obj#visit_statement false st + in + + (* The visitor *) + let obj = + object + inherit [_] A.map_statement as super + + method! visit_Sequence env st1 st2 = + match st1.content with + | Loop _ -> + assert (statement_has_no_loop_break_continue st2); + (replace_breaks_with st1 st2).content + | _ -> super#visit_Sequence env st1 st2 + end + in + + (* Map *) + let body = + match f.body with + | Some body -> Some { body with body = obj#visit_statement () body.body } + | None -> None + in + let f = { f with body } in + log#ldebug + (lazy + ("Before/after [remove_loop_breaks]:\n" + ^ Print.Crate.crate_fun_decl_to_string crate f0 + ^ "\n\n" + ^ Print.Crate.crate_fun_decl_to_string crate f + ^ "\n")); + f + let apply_passes (crate : A.crate) : A.crate = - let passes = [ remove_useless_cf_merges crate ] in + let passes = [ remove_loop_breaks crate ] in let functions = List.fold_left (fun fl pass -> List.map pass fl) crate.functions passes in diff --git a/compiler/Values.ml b/compiler/Values.ml index 6af59087..e737103a 100644 --- a/compiler/Values.ml +++ b/compiler/Values.ml @@ -920,9 +920,29 @@ type abs_kind = See the explanations for [SynthInput]. *) - | Loop of LoopId.id (** The abstraction corresponds to a loop *) + | Loop of (LoopId.id * RegionGroupId.id option) + (** The abstraction corresponds to a loop. + + The region group id is initially [None]. + After we computed a fixed point, we give a unique region group + identifier for each loop abstraction. + *) [@@deriving show, ord] +(** Ancestor for {!abs} iter visitor *) +class ['self] iter_abs_base = + object (_self : 'self) + inherit [_] iter_typed_avalue + method visit_abs_kind : 'env -> abs_kind -> unit = fun _ _ -> () + end + +(** Ancestor for {!abs} map visitor *) +class ['self] map_abs_base = + object (_self : 'self) + inherit [_] map_typed_avalue + method visit_abs_kind : 'env -> abs_kind -> abs_kind = fun _ x -> x + end + (** Abstractions model the parts in the borrow graph where the borrowing relations have been abstracted because of a function call. @@ -931,8 +951,8 @@ type abs_kind = *) type abs = { abs_id : abstraction_id; - kind : (abs_kind[@opaque]); - can_end : (bool[@opaque]); + kind : abs_kind; + can_end : bool; (** Controls whether the region can be ended or not. This allows to "pin" some regions, and is useful when generating @@ -959,7 +979,7 @@ type abs = { { name = "iter_abs"; variety = "iter"; - ancestors = [ "iter_typed_avalue" ]; + ancestors = [ "iter_abs_base" ]; nude = true (* Don't inherit {!VisitorsRuntime.iter} *); concrete = true; }, @@ -967,7 +987,7 @@ type abs = { { name = "map_abs"; variety = "map"; - ancestors = [ "map_typed_avalue" ]; + ancestors = [ "map_abs_base" ]; nude = true (* Don't inherit {!VisitorsRuntime.iter} *); concrete = true; }] |