diff options
author | Son Ho | 2022-11-28 15:57:30 +0100 |
---|---|---|
committer | Son HO | 2023-02-03 11:21:46 +0100 |
commit | 1b4adc1056a97f52d0bf1661611efb6d4727212f (patch) | |
tree | 172efca9e912c96a2c1b01edc3d288b3b85461b0 | |
parent | 59596561873162d632f3d3091485b32a76427ee9 (diff) |
Make progress on environments matches and joins
-rw-r--r-- | compiler/InterpreterBorrows.ml | 6 | ||||
-rw-r--r-- | compiler/InterpreterBorrows.mli | 4 | ||||
-rw-r--r-- | compiler/InterpreterLoops.ml | 396 | ||||
-rw-r--r-- | compiler/Substitute.ml | 143 | ||||
-rw-r--r-- | compiler/Values.ml | 88 |
5 files changed, 563 insertions, 74 deletions
diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml index b85f6692..28caf6e6 100644 --- a/compiler/InterpreterBorrows.ml +++ b/compiler/InterpreterBorrows.ml @@ -1417,6 +1417,9 @@ let get_cf_ctx_no_synth (f : cm_fun) (ctx : C.eval_ctx) : C.eval_ctx = let end_borrow_no_synth config id ctx = get_cf_ctx_no_synth (end_borrow config id) ctx +let end_borrows_no_synth config ids ctx = + get_cf_ctx_no_synth (end_borrows config ids) ctx + let end_abstraction_no_synth config id ctx = get_cf_ctx_no_synth (end_abstraction config id) ctx @@ -1770,8 +1773,7 @@ let convert_value_to_abstractions (abs_kind : V.abs_kind) (can_end : bool) [ { V.value; ty } ]) | V.Symbolic _ -> (* For now, we force all the symbolic values containing borrows to - be eagerly expanded *) - (* We don't support nested borrows for now *) + be eagerly expanded, and we don't support nested borrows *) assert (not (value_has_borrows ctx v.V.value)); (* Return nothing *) [] diff --git a/compiler/InterpreterBorrows.mli b/compiler/InterpreterBorrows.mli index a1a8fb0c..0d1be9c2 100644 --- a/compiler/InterpreterBorrows.mli +++ b/compiler/InterpreterBorrows.mli @@ -33,6 +33,10 @@ val end_abstractions : C.config -> V.AbstractionId.Set.t -> cm_fun (** End a borrow and return the resulting environment, ignoring synthesis *) val end_borrow_no_synth : C.config -> V.BorrowId.id -> C.eval_ctx -> C.eval_ctx +(** End a set of borrows and return the resulting environment, ignoring synthesis *) +val end_borrows_no_synth : + C.config -> V.BorrowId.Set.t -> C.eval_ctx -> C.eval_ctx + (** End an abstraction and return the resulting environment, ignoring synthesis *) val end_abstraction_no_synth : C.config -> V.AbstractionId.id -> C.eval_ctx -> C.eval_ctx diff --git a/compiler/InterpreterLoops.ml b/compiler/InterpreterLoops.ml index 702420cd..08607deb 100644 --- a/compiler/InterpreterLoops.ml +++ b/compiler/InterpreterLoops.ml @@ -28,8 +28,38 @@ type cnt_thresholds = { sid : V.SymbolicValueId.id; bid : V.BorrowId.id; did : C.DummyVarId.id; + rid : T.RegionId.id; } +(* TODO: document. + TODO: we might not use the bounds properly, use sets instead. + TODO: actually, bounds are good +*) +type match_ctx = { + ctx : C.eval_ctx; + aids : V.AbstractionId.Set.t; + sids : V.SymbolicValueId.Set.t; + bids : V.BorrowId.Set.t; +} + +let mk_match_ctx (ctx : C.eval_ctx) : match_ctx = + let aids = V.AbstractionId.Set.empty in + let sids = V.SymbolicValueId.Set.empty in + let bids = V.BorrowId.Set.empty in + { ctx; aids; sids; bids } + +type updt_env_kind = + | EndAbsInSrc of V.AbstractionId.id + | EndBorrowsInSrc of V.BorrowId.Set.t + | EndBorrowInSrc of V.BorrowId.id + | EndBorrowInTgt of V.BorrowId.id + | EndBorrowsInTgt of V.BorrowId.Set.t + +(** Utility exception *) +exception ValueMatchFailure of updt_env_kind + +type joined_ctx_or_update = (match_ctx, updt_env_kind) result + (** Union Find *) module UnionFind = UF.Make (UF.StoreMap) @@ -248,26 +278,328 @@ let collapse_ctx (loop_id : V.LoopId.id) (thresh : cnt_thresholds) (* Return the new context *) !ctx -(* TODO: document. - TODO: we might not use the bounds properly, use sets instead. -*) -type match_ctx = { - ctx : C.eval_ctx; - aids : V.AbstractionId.Set.t; - sids : V.SymbolicValueId.Set.t; - bids : V.BorrowId.Set.t; -} +let rec match_types (check_regions : 'r -> 'r -> unit) (ctx : C.eval_ctx) + (ty0 : 'r T.ty) (ty1 : 'r T.ty) : unit = + let match_rec = match_types check_regions ctx in + match (ty0, ty1) with + | Adt (id0, regions0, tys0), Adt (id1, regions1, tys1) -> + assert (id0 = id1); + List.iter + (fun (id0, id1) -> check_regions id0 id1) + (List.combine regions0 regions1); + List.iter (fun (ty0, ty1) -> match_rec ty0 ty1) (List.combine tys0 tys1) + | TypeVar vid0, TypeVar vid1 -> assert (vid0 = vid1) + | Bool, Bool | Char, Char | Never, Never | Str, Str -> () + | Integer int_ty0, Integer int_ty1 -> assert (int_ty0 = int_ty1) + | Array ty0, Array ty1 | Slice ty0, Slice ty1 -> match_rec ty0 ty1 + | Ref (r0, ty0, k0), Ref (r1, ty1, k1) -> + check_regions r0 r1; + match_rec ty0 ty1; + assert (k0 = k1) + | _ -> raise (Failure "Unreachable") + +let match_rtypes (rid_map : T.RegionId.InjSubst.t ref) (ctx : C.eval_ctx) + (ty0 : T.rty) (ty1 : T.rty) : unit = + let lookup_rid (id : T.RegionId.id) : T.RegionId.id = + T.RegionId.InjSubst.find_with_default id id !rid_map + in + let check_regions r0 r1 = + match (r0, r1) with + | T.Static, T.Static -> () + | T.Var id0, T.Var id1 -> + let id0 = lookup_rid id0 in + assert (id0 = id1) + | _ -> raise (Failure "Unexpected") + in + match_types check_regions ctx ty0 ty1 -let mk_match_ctx (ctx : C.eval_ctx) : match_ctx = - let aids = V.AbstractionId.Set.empty in - let sids = V.SymbolicValueId.Set.empty in - let bids = V.BorrowId.Set.empty in - { ctx; aids; sids; bids } +(** This function raises exceptions of kind {!ValueMatchFailue}. + + [convertible]: the function updates it to [false] if the result of the + merge is not the result of an alpha-conversion. For instance, if we + match two primitive values which are not equal, and thus introduce a + symbolic value for the result: + {[ + 0 : u32, 1 : u32 ~~> s : u32 where s fresh + ]} + *) +let rec match_typed_values (config : C.config) (thresh : cnt_thresholds) + (convertible : bool ref) (rid_map : T.RegionId.InjSubst.t ref) + (bid_map : V.BorrowId.InjSubst.t ref) + (sid_map : V.SymbolicValueId.InjSubst.t ref) (ctx : C.eval_ctx) + (v0 : V.typed_value) (v1 : V.typed_value) : V.typed_value = + let match_rec = + match_typed_values config thresh convertible rid_map bid_map sid_map ctx + in + let lookup_bid (id : V.BorrowId.id) : V.BorrowId.id = + V.BorrowId.InjSubst.find_with_default id id !bid_map + in + let lookup_bids (ids : V.BorrowId.Set.t) : V.BorrowId.Set.t = + V.BorrowId.Set.map lookup_bid ids + in + let map_bid (id0 : V.BorrowId.id) (id1 : V.BorrowId.id) : V.BorrowId.id = + assert (V.BorrowId.Ord.compare id0 thresh.bid >= 0); + assert (V.BorrowId.Ord.compare id1 thresh.bid >= 0); + assert (not (V.BorrowId.InjSubst.mem id0 !bid_map)); + bid_map := V.BorrowId.InjSubst.add id0 id1 !bid_map; + id1 + in + let lookup_sid (id : V.SymbolicValueId.id) : V.SymbolicValueId.id = + match V.SymbolicValueId.InjSubst.find_opt id !sid_map with + | None -> id + | Some id -> id + in + assert (v0.V.ty = v1.V.ty); + match (v0.V.value, v1.V.value) with + | V.Primitive pv0, V.Primitive pv1 -> + if pv0 = pv1 then v1 else raise (Failure "Unimplemented") + | V.Adt av0, V.Adt av1 -> + if av0.variant_id = av1.variant_id then + let fields = List.combine av0.field_values av1.field_values in + let field_values = List.map (fun (f0, f1) -> match_rec f0 f1) fields in + let value : V.value = + V.Adt { variant_id = av0.variant_id; field_values } + in + { V.value; ty = v1.V.ty } + else ( + convertible := false; + (* For now, we don't merge values which contain borrows *) + (* TODO: *) + raise (Failure "Unimplemented")) + | Bottom, Bottom -> v1 + | Borrow bc0, Borrow bc1 -> + let bc = + match (bc0, bc1) with + | SharedBorrow (mv0, bid0), SharedBorrow (mv1, bid1) -> + let bid0 = lookup_bid bid0 in + (* Not completely sure what to do with the meta-value. If a shared + symbolic value gets expanded in a branch, it may be simplified + (by being folded back to a symbolic value) upon doing the join, + which as a result would lead to code where it is considered as + mutable (which is sound). On the other hand, if we access a + symbolic value in a loop, the translated loop should take it as + input anyway, so maybe this actually leads to equivalent + code. + *) + let mv = match_rec mv0 mv1 in + let bid = if bid0 = bid1 then bid1 else map_bid bid0 bid1 in + V.SharedBorrow (mv, bid) + | MutBorrow (bid0, bv0), MutBorrow (bid1, bv1) -> + let bid0 = lookup_bid bid0 in + let bv = match_rec bv0 bv1 in + let bid = if bid0 = bid1 then bid1 else map_bid bid0 bid1 in + V.MutBorrow (bid, bv) + | ReservedMutBorrow _, _ + | _, ReservedMutBorrow _ + | SharedBorrow _, MutBorrow _ + | MutBorrow _, SharedBorrow _ -> + (* If we get here, either there is a typing inconsistency, or we are + trying to match a reserved borrow, which shouldn't happen because + reserved borrow should be eliminated very quickly - they are introduced + just before function calls which activate them *) + raise (Failure "Unexpected") + in + { V.value = V.Borrow bc; V.ty = v1.V.ty } + | Loan lc0, Loan lc1 -> + (* TODO: maybe we should enforce that the ids are always exactly the same - + without matching *) + let lc = + match (lc0, lc1) with + | SharedLoan (ids0, sv0), SharedLoan (ids1, sv1) -> + let ids0 = lookup_bids ids0 in + (* Not sure what to do if the ids don't match *) + let ids = + if ids0 = ids1 then ids1 else raise (Failure "Unimplemented") + in + let sv = match_rec sv0 sv1 in + V.SharedLoan (ids, sv) + | MutLoan id0, MutLoan id1 -> + let id0 = lookup_bid id0 in + let id = if id0 = id1 then id1 else map_bid id0 id1 in + V.MutLoan id + | SharedLoan _, MutLoan _ | MutLoan _, SharedLoan _ -> + raise (Failure "Unreachable") + in + { V.value = Loan lc; ty = v1.V.ty } + | Symbolic sv0, Symbolic sv1 -> + (* TODO: id check mapping *) + let id0 = lookup_sid sv0.sv_id in + let id1 = sv1.sv_id in + if id0 = id1 then ( + assert (sv0.sv_kind = sv1.sv_kind); + (* Sanity check: the types should be the same *) + match_rtypes rid_map ctx sv0.sv_ty sv1.sv_ty; + (* Return *) + v1) + else ( + (* For now, we force all the symbolic values containing borrows to + be eagerly expanded, and we don't support nested borrows *) + assert (not (value_has_borrows ctx v0.V.value)); + assert (not (value_has_borrows ctx v1.V.value)); + raise (Failure "Unimplemented")) + | Loan lc, _ -> ( + match lc with + | SharedLoan (ids, _) -> raise (ValueMatchFailure (EndBorrowsInSrc ids)) + | MutLoan id -> raise (ValueMatchFailure (EndBorrowInSrc id))) + | _, Loan lc -> ( + match lc with + | SharedLoan (ids, _) -> raise (ValueMatchFailure (EndBorrowsInTgt ids)) + | MutLoan id -> raise (ValueMatchFailure (EndBorrowInTgt id))) + | Symbolic _, _ -> raise (Failure "Unimplemented") + | _, Symbolic _ -> raise (Failure "Unimplemented") + | _ -> raise (Failure "Unreachable") + +(*(** This function raises exceptions of kind {!ValueMatchFailue} *) + let rec match_typed_avalues (config : C.config) (thresh : cnt_thresholds) + (rid_map : T.RegionId.InjSubst.t ref) (bid_map : V.BorrowId.InjSubst.t ref) + (sid_map : V.SymbolicValueId.InjSubst.t ref) + (aid_map : V.AbstractionId.InjSubst.t ref) (ctx : C.eval_ctx) + (v0 : V.typed_avalue) (v1 : V.typed_avalue) : V.typed_avalue = + let match_rec = + match_typed_avalues config thresh rid_map bid_map sid_map ctx + in + (* TODO: factorize those helpers with [match_typed_values] (write a functor?) *) + let lookup_bid (id : V.BorrowId.id) : V.BorrowId.id = + V.BorrowId.InjSubst.find_with_default id id !bid_map + in + let lookup_bids (ids : V.BorrowId.Set.t) : V.BorrowId.Set.t = + V.BorrowId.Set.map lookup_bid ids + in + let map_bid (id0 : V.BorrowId.id) (id1 : V.BorrowId.id) : V.BorrowId.id = + assert (V.BorrowId.Ord.compare id0 thresh.bid >= 0); + assert (V.BorrowId.Ord.compare id1 thresh.bid >= 0); + assert (not (V.BorrowId.InjSubst.mem id0 !bid_map)); + bid_map := V.BorrowId.InjSubst.add id0 id1 !bid_map; + id1 + in + let lookup_sid (id : V.SymbolicValueId.id) : V.SymbolicValueId.id = + match V.SymbolicValueId.InjSubst.find_opt id !sid_map with + | None -> id + | Some id -> id + in + assert (v0.V.ty = v1.V.ty); + match (v0.V.value, v1.V.value) with + | V.APrimitive pv0, V.APrimitive pv1 -> + if pv0 = pv1 then v1 else raise (Failure "Unimplemented") + | V.AAdt av0, V.AAdt av1 -> + if av0.variant_id = av1.variant_id then + let fields = List.combine av0.field_values av1.field_values in + let field_values = List.map (fun (f0, f1) -> match_rec f0 f1) fields in + let value : V.value = + V.Adt { variant_id = av0.variant_id; field_values } + in + { V.value; ty = v1.V.ty } + else raise (Failure "Unimplemented") + | Bottom, Bottom -> v1 + | Borrow bc0, Borrow bc1 -> + let bc = + match (bc0, bc1) with + | SharedBorrow (mv0, bid0), SharedBorrow (mv1, bid1) -> + let bid0 = lookup_bid bid0 in + (* Not completely sure what to do with the meta-value. If a shared + symbolic value gets expanded in a branch, it may be simplified + (by being folded back to a symbolic value) upon doing the join, + which as a result would lead to code where it is considered as + mutable (which is sound). On the other hand, if we access a + symbolic value in a loop, the translated loop should take it as + input anyway, so maybe this actually leads to equivalent + code. + *) + let mv = match_rec mv0 mv1 in + let bid = if bid0 = bid1 then bid1 else map_bid bid0 bid1 in + V.SharedBorrow (mv, bid) + | MutBorrow (bid0, bv0), MutBorrow (bid1, bv1) -> + let bid0 = lookup_bid bid0 in + let bv = match_rec bv0 bv1 in + let bid = if bid0 = bid1 then bid1 else map_bid bid0 bid1 in + V.MutBorrow (bid, bv) + | ReservedMutBorrow _, _ + | _, ReservedMutBorrow _ + | SharedBorrow _, MutBorrow _ + | MutBorrow _, SharedBorrow _ -> + (* If we get here, either there is a typing inconsistency, or we are + trying to match a reserved borrow, which shouldn't happen because + reserved borrow should be eliminated very quickly - they are introduced + just before function calls which activate them *) + raise (Failure "Unexpected") + in + { V.value = V.Borrow bc; V.ty = v1.V.ty } + | Loan lc0, Loan lc1 -> + (* TODO: maybe we should enforce that the ids are always exactly the same - + without matching *) + let lc = + match (lc0, lc1) with + | SharedLoan (ids0, sv0), SharedLoan (ids1, sv1) -> + let ids0 = lookup_bids ids0 in + (* Not sure what to do if the ids don't match *) + let ids = + if ids0 = ids1 then ids1 else raise (Failure "Unimplemented") + in + let sv = match_rec sv0 sv1 in + V.SharedLoan (ids, sv) + | MutLoan id0, MutLoan id1 -> + let id0 = lookup_bid id0 in + let id = if id0 = id1 then id1 else map_bid id0 id1 in + V.MutLoan id + | SharedLoan _, MutLoan _ | MutLoan _, SharedLoan _ -> + raise (Failure "Unreachable") + in + { V.value = Loan lc; ty = v1.V.ty } + | Symbolic sv0, Symbolic sv1 -> + (* TODO: id check mapping *) + let id0 = lookup_sid sv0.sv_id in + let id1 = sv1.sv_id in + if id0 = id1 then ( + assert (sv0.sv_kind = sv1.sv_kind); + (* Sanity check: the types should be the same *) + match_rtypes rid_map ctx sv0.sv_ty sv1.sv_ty; + (* Return *) + v1) + else ( + (* For now, we force all the symbolic values containing borrows to + be eagerly expanded, and we don't support nested borrows *) + assert (not (value_has_borrows ctx v0.V.value)); + assert (not (value_has_borrows ctx v1.V.value)); + raise (Failure "Unimplemented")) + | Loan lc, _ -> ( + match lc with + | SharedLoan (ids, _) -> raise (ValueMatchFailure (EndBorrowsInSrc ids)) + | MutLoan id -> raise (ValueMatchFailure (EndBorrowInSrc id))) + | _, Loan lc -> ( + match lc with + | SharedLoan (ids, _) -> raise (ValueMatchFailure (EndBorrowsInTgt ids)) + | MutLoan id -> raise (ValueMatchFailure (EndBorrowInTgt id))) + | _ -> raise (Failure "Unreachable")*) + +(** Apply substitutions in the first abstraction, then merge the abstractions together. *) +let subst_merge_abstractions (loop_id : V.LoopId.id) (thresh : cnt_thresholds) + (rid_map : T.RegionId.InjSubst.t) (bid_map : V.BorrowId.InjSubst.t) + (sid_map : V.SymbolicValueId.InjSubst.t) (ctx : C.eval_ctx) (abs0 : V.abs) + (abs1 : V.abs) : V.abs = + (* Apply the substitutions in the first abstraction *) + let rsubst id = + assert (T.RegionId.Ord.compare id thresh.rid >= 0); + T.RegionId.InjSubst.find_with_default id id rid_map + in + let rvsubst id = id in + let tsubst id = id in + let ssubst id = + assert (V.SymbolicValueId.Ord.compare id thresh.sid >= 0); + V.SymbolicValueId.InjSubst.find_with_default id id sid_map + in + let bsubst id = + assert (V.BorrowId.Ord.compare id thresh.bid >= 0); + V.BorrowId.InjSubst.find_with_default id id bid_map + in + let asubst id = id in + let abs0 = + Substitute.abs_subst_ids rsubst rvsubst tsubst ssubst bsubst asubst abs0 + in -type joined_ctx_or_update = - | EndAbs of V.AbstractionId.id - | EndBorrow of V.BorrowId.id - | JoinedCtx of match_ctx + (* Merge the two abstractions *) + let abs_kind = V.Loop loop_id in + let can_end = false in + merge_abstractions abs_kind can_end ctx abs0 abs1 (** Merge a borrow with the abstraction containing the associated loan, where the abstraction must be a *loop abstraction* (we don't synthesize code during @@ -306,8 +638,10 @@ let rec merge_borrows_with_parent_loop_abs (config : C.config) (* TODO: we probably don't need an [match_ctx], and actually we might not use the bounds propertly. + TODO: remove *) -let match_ctx_with_target (config : C.config) (tgt_mctx : match_ctx) : cm_fun = +let match_ctx_with_target_old (config : C.config) (tgt_mctx : match_ctx) : + cm_fun = fun cf src_ctx -> (* We first reorganize [ctx] so that we can match it with [tgt_mctx] *) (* First, collect all the borrows and abstractions which are in [ctx] @@ -351,12 +685,12 @@ let loop_join_entry_ctx_with_continue_ctx (ctx0 : match_ctx) (ctx1 : C.eval_ctx) let rec loop_join_entry_ctx_with_continue_ctxs (ctx0 : match_ctx) (ctxs : C.eval_ctx list) : joined_ctx_or_update = match ctxs with - | [] -> JoinedCtx ctx0 + | [] -> Ok ctx0 | ctx1 :: ctxs -> ( let res = loop_join_entry_ctx_with_continue_ctx ctx0 ctx1 in match res with - | EndAbs _ | EndBorrow _ -> res - | JoinedCtx ctx0 -> loop_join_entry_ctx_with_continue_ctxs ctx0 ctxs) + | Error _ -> res + | Ok ctx0 -> loop_join_entry_ctx_with_continue_ctxs ctx0 ctxs) let compute_loop_entry_fixed_point (config : C.config) (eval_loop_body : st_cm_fun) (ctx0 : C.eval_ctx) : match_ctx = @@ -416,17 +750,25 @@ let compute_loop_entry_fixed_point (config : C.config) (* Check if the join succeeded, or if we need to end abstractions/borrows in the original environment first *) match join_ctxs mctx with - | EndAbs id -> + | Error (EndAbsInSrc id) -> let ctx1 = InterpreterBorrows.end_abstraction_no_synth config id mctx.ctx in eval_iteration_then_join { mctx with ctx = ctx1 } - | EndBorrow id -> + | Error (EndBorrowInSrc id) -> let ctx1 = InterpreterBorrows.end_borrow_no_synth config id mctx.ctx in eval_iteration_then_join { mctx with ctx = ctx1 } - | JoinedCtx mctx1 -> + | Error (EndBorrowsInSrc ids) -> + let ctx1 = + InterpreterBorrows.end_borrows_no_synth config ids mctx.ctx + in + eval_iteration_then_join { mctx with ctx = ctx1 } + | Error (EndBorrowInTgt _ | EndBorrowsInTgt _) -> + (* Shouldn't happen here *) + raise (Failure "Unreachable") + | Ok mctx1 -> (* The join succeeded: check if we reached a fixed point, otherwise iterate *) if equiv_ctxs mctx mctx1 then mctx1 @@ -487,7 +829,7 @@ let eval_loop_symbolic (config : C.config) (eval_loop_body : st_cm_fun) : (* Compute the fixed point at the loop entrance *) let mctx = compute_loop_entry_fixed_point config eval_loop_body ctx in (* Synthesize the end of the function *) - let end_expr = match_ctx_with_target config mctx (cf EndEnterLoop) ctx in + let end_expr = match_ctx_with_target_old config mctx (cf EndEnterLoop) ctx in (* Synthesize the loop body by evaluating it, with the continuation for after the loop starting at the *fixed point*, but with a special treatment for the [Break] and [Continue] cases *) @@ -502,7 +844,7 @@ let eval_loop_symbolic (config : C.config) (eval_loop_body : st_cm_fun) : | Continue i -> (* We don't support nested loops for now *) assert (i = 0); - match_ctx_with_target config mctx (cf EndContinue) ctx + match_ctx_with_target_old config mctx (cf EndContinue) ctx | Unit | EndEnterLoop | EndContinue -> (* For why we can't get [Unit], see the comments inside {!eval_loop_concrete}. For [EndEnterLoop] and [EndContinue]: we don't support nested loops for now. diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml index eb61f076..0978e078 100644 --- a/compiler/Substitute.ml +++ b/compiler/Substitute.ml @@ -9,30 +9,23 @@ module E = Expressions module A = LlbcAst module C = Contexts -(** Substitute types variables and regions in a type. - - TODO: we can reimplement that with visitors. - *) -let rec ty_substitute (rsubst : 'r1 -> 'r2) - (tsubst : T.TypeVarId.id -> 'r2 T.ty) (ty : 'r1 T.ty) : 'r2 T.ty = +(** Substitute types variables and regions in a type. *) +let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty) + (ty : 'r1 T.ty) : 'r2 T.ty = let open T in - let subst = ty_substitute rsubst tsubst in - (* helper *) - match ty with - | Adt (def_id, regions, tys) -> - Adt (def_id, List.map rsubst regions, List.map subst tys) - | Array aty -> Array (subst aty) - | Slice sty -> Slice (subst sty) - | Ref (r, ref_ty, ref_kind) -> Ref (rsubst r, subst ref_ty, ref_kind) - (* Below variants: we technically return the same value, but because - one has type ['r1 ty] and the other has type ['r2 ty], we need to - deconstruct then reconstruct *) - | Bool -> Bool - | Char -> Char - | Never -> Never - | Integer int_ty -> Integer int_ty - | Str -> Str - | TypeVar vid -> tsubst vid + let visitor = + object + inherit [_] map_ty + method visit_'r _ r = rsubst r + method! visit_TypeVar _ id = tsubst id + + method! visit_type_var_id _ _ = + (* We should never get here because we reimplemented [visit_TypeVar] *) + raise (Failure "Unexpected") + end + in + + visitor#visit_ty () ty (** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *) let erase_regions (ty : T.rty) : T.ety = @@ -360,3 +353,107 @@ let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id) in let regions_hierarchy = List.map subst_region_group sg.A.regions_hierarchy in { A.regions_hierarchy; inputs; output } + +(** Substitute identifiers in a type *) +let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id) (ty : 'r T.ty) + : 'r T.ty = + let open T in + let visitor = + object + inherit [_] map_ty + method visit_'r _ r = r + method! visit_type_var_id _ id = tsubst id + end + in + + visitor#visit_ty () ty + +(* This visitor is a mess... + + We want to write a class which visits abstractions, values, etc. *and their + types* to substitute identifiers. + + The issue comes is that we derive two specialized types (ety and rty) from a + polymorphic type (ty). Because of this, there are typing issues with + [visit_'r] if we define a class which visits objects of types [ety] and [rty] + while inheriting a class which visit [ty]... +*) +let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id) + (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) + (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) + (bsubst : V.BorrowId.id -> V.BorrowId.id) + (asubst : V.AbstractionId.id -> V.AbstractionId.id) = + let subst_rty = + object + inherit [_] T.map_ty + + method visit_'r _ r = + match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid) + + method! visit_type_var_id _ id = tsubst id + end + in + + let visitor = + object (self : 'self) + inherit [_] V.map_abs + method! visit_borrow_id _ bid = bsubst bid + method! visit_loan_id _ bid = bsubst bid + + method! visit_symbolic_value env sv = + let sv_id = ssubst sv.sv_id in + let sv_ty = subst_rty#visit_ty env sv.sv_ty in + { sv with V.sv_id; sv_ty } + + method! visit_ety _ ty = ty_substitute_ids tsubst ty + + (** We *do* visit meta-values *) + method! visit_mvalue env v = self#visit_typed_value env v + + method! visit_region_id _ id = rsubst id + method! visit_region_var_id _ id = rvsubst id + method! visit_abstraction_id _ id = asubst id + end + in + + object + method visit_ety (x : T.ety) : T.ety = visitor#visit_ety () x + method visit_rty (x : T.rty) : T.rty = visitor#visit_rty () x + + method visit_typed_value (x : V.typed_value) : V.typed_value = + visitor#visit_typed_value () x + + method visit_typed_avalue (x : V.typed_avalue) : V.typed_avalue = + visitor#visit_typed_avalue () x + + method visit_abs (x : V.abs) : V.abs = visitor#visit_abs () x + end + +let typed_value_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) + (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) + (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) + (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_value) : + V.typed_value = + let asubst _ = raise (Failure "Unreachable") in + (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst) + #visit_typed_value v + +let typed_avalue_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) + (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) + (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) + (bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_avalue) : + V.typed_avalue = + let asubst _ = raise (Failure "Unreachable") in + (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst) + #visit_typed_avalue v + +let abs_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id) + (rvsubst : T.RegionVarId.id -> T.RegionVarId.id) + (tsubst : T.TypeVarId.id -> T.TypeVarId.id) + (ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id) + (bsubst : V.BorrowId.id -> V.BorrowId.id) + (asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : V.abs) : V.abs = + (subst_ids_visitor rsubst rvsubst tsubst ssubst bsubst asubst)#visit_abs x diff --git a/compiler/Values.ml b/compiler/Values.ml index 86cb9098..f206f65a 100644 --- a/compiler/Values.ml +++ b/compiler/Values.ml @@ -12,9 +12,9 @@ module AbstractionId = IdGen () module FunCallId = IdGen () module LoopId = IdGen () -type big_int = PrimitiveValues.big_int [@@deriving show] -type scalar_value = PrimitiveValues.scalar_value [@@deriving show] -type primitive_value = PrimitiveValues.primitive_value [@@deriving show] +type big_int = PrimitiveValues.big_int [@@deriving show, ord] +type scalar_value = PrimitiveValues.scalar_value [@@deriving show, ord] +type primitive_value = PrimitiveValues.primitive_value [@@deriving show, ord] (** The kind of a symbolic value, which precises how the value was generated. @@ -45,7 +45,7 @@ type sv_kind = | LoopOutput (** The output of a loop (seen as a forward computation) *) | LoopGivenBack (** A value given back by a loop (when ending abstractions while going backwards) *) -[@@deriving show] +[@@deriving show, ord] (** A symbolic value *) type symbolic_value = { @@ -53,12 +53,12 @@ type symbolic_value = { sv_id : SymbolicValueId.id; sv_ty : rty; } -[@@deriving show] +[@@deriving show, ord] -type borrow_id = BorrowId.id [@@deriving show] -type borrow_id_set = BorrowId.Set.t [@@deriving show] -type loan_id = BorrowId.id [@@deriving show] -type loan_id_set = BorrowId.Set.t [@@deriving show] +type borrow_id = BorrowId.id [@@deriving show, ord] +type borrow_id_set = BorrowId.Set.t [@@deriving show, ord] +type loan_id = BorrowId.id [@@deriving show, ord] +type loan_id_set = BorrowId.Set.t [@@deriving show, ord] (** Ancestor for {!typed_value} iter visitor *) class ['self] iter_typed_value_base = @@ -70,6 +70,7 @@ class ['self] iter_typed_value_base = method visit_erased_region : 'env -> erased_region -> unit = fun _ _ -> () method visit_symbolic_value : 'env -> symbolic_value -> unit = fun _ _ -> () + method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> () method visit_ety : 'env -> ety -> unit = fun _ _ -> () method visit_borrow_id : 'env -> borrow_id -> unit = fun _ _ -> () method visit_loan_id : 'env -> loan_id -> unit = fun _ _ -> () @@ -96,6 +97,7 @@ class ['self] map_typed_value_base = fun _ sv -> sv method visit_ety : 'env -> ety -> ety = fun _ ty -> ty + method visit_variant_id : 'env -> variant_id -> variant_id = fun _ x -> x method visit_borrow_id : 'env -> borrow_id -> borrow_id = fun _ id -> id method visit_loan_id : 'env -> loan_id -> loan_id = fun _ id -> id @@ -122,7 +124,7 @@ type value = *) and adt_value = { - variant_id : (VariantId.id option[@opaque]); + variant_id : variant_id option; field_values : typed_value list; } @@ -202,6 +204,7 @@ and mvalue = typed_value and typed_value = { value : value; ty : ety } [@@deriving show, + ord, visitors { name = "iter_typed_value_visit_mvalue"; @@ -226,7 +229,7 @@ and typed_value = { value : value; ty : ety } TODO: we may want to create wrappers, to prevent mixing meta values and regular values. *) -type msymbolic_value = symbolic_value [@@deriving show] +type msymbolic_value = symbolic_value [@@deriving show, ord] class ['self] iter_typed_value = object (_self : 'self) @@ -275,6 +278,7 @@ type abstract_shared_borrow = | AsbProjReborrows of symbolic_value * rty [@@deriving show, + ord, visitors { name = "iter_abstract_shared_borrow"; @@ -296,6 +300,7 @@ type abstract_shared_borrow = type abstract_shared_borrows = abstract_shared_borrow list [@@deriving show, + ord, visitors { name = "iter_abstract_shared_borrows"; @@ -383,6 +388,7 @@ type aproj = | AIgnoredProjBorrows [@@deriving show, + ord, visitors { name = "iter_aproj"; @@ -400,25 +406,61 @@ type aproj = concrete = true; }] -type region = RegionVarId.id Types.region [@@deriving show] -type abstraction_id = AbstractionId.id [@@deriving show] +type region = RegionVarId.id Types.region [@@deriving show, ord] +type region_var_id = RegionVarId.id [@@deriving show, ord] +type region_id = RegionId.id [@@deriving show, ord] +type region_id_set = RegionId.Set.t [@@deriving show, ord] +type abstraction_id = AbstractionId.id [@@deriving show, ord] +type abstraction_id_set = AbstractionId.Set.t [@@deriving show, ord] (** Ancestor for {!typed_avalue} iter visitor *) class ['self] iter_typed_avalue_base = - object (_self : 'self) + object (self : 'self) inherit [_] iter_aproj - method visit_region : 'env -> region -> unit = fun _ _ -> () + method visit_region_var_id : 'env -> region_var_id -> unit = fun _ _ -> () + + method visit_region : 'env -> region -> unit = + fun env r -> + match r with + | Static -> () + | Var rid -> self#visit_region_var_id env rid + + method visit_region_id : 'env -> region_id -> unit = fun _ _ -> () + + method visit_region_id_set : 'env -> region_id_set -> unit = + fun env ids -> RegionId.Set.iter (self#visit_region_id env) ids + method visit_abstraction_id : 'env -> abstraction_id -> unit = fun _ _ -> () + + method visit_abstraction_id_set : 'env -> abstraction_id_set -> unit = + fun env ids -> AbstractionId.Set.iter (self#visit_abstraction_id env) ids end (** Ancestor for {!typed_avalue} map visitor *) class ['self] map_typed_avalue_base = - object (_self : 'self) + object (self : 'self) inherit [_] map_aproj - method visit_region : 'env -> region -> region = fun _ r -> r + + method visit_region_var_id : 'env -> region_var_id -> region_var_id = + fun _ x -> x + + method visit_region : 'env -> region -> region = + fun env r -> + match r with + | Static -> Static + | Var rid -> Var (self#visit_region_var_id env rid) + + method visit_region_id : 'env -> region_id -> region_id = fun _ x -> x + + method visit_region_id_set : 'env -> region_id_set -> region_id_set = + fun env ids -> RegionId.Set.map (self#visit_region_id env) ids method visit_abstraction_id : 'env -> abstraction_id -> abstraction_id = fun _ x -> x + + method visit_abstraction_id_set + : 'env -> abstraction_id_set -> abstraction_id_set = + fun env ids -> AbstractionId.Set.map (self#visit_abstraction_id env) ids end (** Abstraction values are used inside of abstractions to properly model @@ -774,6 +816,7 @@ and aborrow_content = and typed_avalue = { value : avalue; ty : rty } [@@deriving show, + ord, visitors { name = "iter_typed_avalue"; @@ -819,7 +862,7 @@ type abs_kind = See the explanations for [SynthInput]. *) | Loop of LoopId.id (** The abstraction corresponds to a loop *) -[@@deriving show] +[@@deriving show, ord] (** Abstractions model the parts in the borrow graph where the borrowing relations have been abstracted because of a function call. @@ -841,17 +884,18 @@ type abs = { don't need to end the return region for 'b (if it is the case, it means the function doesn't borrow check). *) - parents : (AbstractionId.Set.t[@opaque]); (** The parent abstractions *) - original_parents : (AbstractionId.id list[@opaque]); + parents : abstraction_id_set; (** The parent abstractions *) + original_parents : abstraction_id list; (** The original list of parents, ordered. This is used for synthesis. TODO: remove? *) - regions : (RegionId.Set.t[@opaque]); (** Regions owned by this abstraction *) - ancestors_regions : (RegionId.Set.t[@opaque]); + regions : region_id_set; (** Regions owned by this abstraction *) + ancestors_regions : region_id_set; (** Union of the regions owned by this abstraction's ancestors (not including the regions of this abstraction itself) *) avalues : typed_avalue list; (** The values in this abstraction *) } [@@deriving show, + ord, visitors { name = "iter_abs"; |