diff options
author | Son Ho | 2022-01-25 18:49:11 +0100 |
---|---|---|
committer | Son Ho | 2022-01-25 18:49:11 +0100 |
commit | 42da25ddae3deb8df125ca5d1963a0b40d683c2a (patch) | |
tree | 0cb94d1bf5537162bfc70ef0fd6685fa3ca2bc26 /src/SymbolicToPure.ml | |
parent | c9d8b44983e6111615400b7f2891e8f928009cd3 (diff) |
Make progress on SymbolicToPure.translate_end_abstraction
Diffstat (limited to '')
-rw-r--r-- | src/SymbolicToPure.ml | 192 |
1 files changed, 152 insertions, 40 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 11d7a657..97d9baf1 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -99,10 +99,14 @@ let translate_type_name (def : T.type_def) : Id.name = def.T.name type type_context = { types_infos : TA.type_infos; + cfim_type_defs : T.type_def TypeDefId.Map.t; type_defs : type_def TypeDefId.Map.t; } -type fun_context = { fun_defs : fun_def FunDefId.Map.t } +type fun_context = { + cfim_fun_defs : A.fun_def FunDefId.Map.t; + fun_defs : fun_def FunDefId.Map.t; +} (* TODO: do we really need that actually? *) type synth_ctx = { @@ -113,15 +117,57 @@ type synth_ctx = { declarations : M.declaration_group list; } +type call_info = { + forward : S.call; + backwards : V.abs T.RegionGroupId.Map.t; + (** TODO: not sure we need this anymore *) +} +(** Whenever we translate a function call or an ended abstraction, we + store the related information (this is useful when translating ended + children abstractions) + *) + type bs_ctx = { type_context : type_context; + fun_context : fun_context; fun_def : A.fun_def; bid : T.RegionGroupId.id option; + calls : call_info V.FunCallId.Map.t; + abstractions : V.abs V.AbstractionId.Map.t; } (** Body synthesis context *) -let bs_ctx_lookup_type_def (id : TypeDefId.id) (ctx : bs_ctx) : type_def = - TypeDefId.Map.find id ctx.type_context.type_defs +(*let bs_ctx_lookup_type_def (id : TypeDefId.id) (ctx : bs_ctx) : type_def = + TypeDefId.Map.find id ctx.type_context.type_defs*) +let bs_ctx_lookup_cfim_type_def (id : TypeDefId.id) (ctx : bs_ctx) : T.type_def + = + TypeDefId.Map.find id ctx.type_context.cfim_type_defs + +let bs_ctx_lookup_cfim_fun_def (id : FunDefId.id) (ctx : bs_ctx) : A.fun_def = + FunDefId.Map.find id ctx.fun_context.cfim_fun_defs + +let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) + (ctx : bs_ctx) : bs_ctx = + let calls = ctx.calls in + assert (not (V.FunCallId.Map.mem call_id calls)); + let info = { forward; backwards = T.RegionGroupId.Map.empty } in + let calls = V.FunCallId.Map.add call_id info calls in + { ctx with calls } + +let bs_ctx_register_backward_call (abs : V.abs) (ctx : bs_ctx) : bs_ctx = + (* Insert the abstraction in the call informations *) + let back_id = Option.get abs.back_id in + let info = V.FunCallId.Map.find abs.call_id ctx.calls in + assert (not (T.RegionGroupId.Map.mem back_id info.backwards)); + let backwards = T.RegionGroupId.Map.add back_id abs info.backwards in + let info = { info with backwards } in + let calls = V.FunCallId.Map.add abs.call_id info ctx.calls in + (* Insert the abstraction in the abstractions map *) + let abstractions = ctx.abstractions in + assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions)); + let abstractions = V.AbstractionId.Map.add abs.abs_id abs abstractions in + (* Update the context *) + { ctx with calls; abstractions } let rec translate_sty (ty : T.sty) : ty = let translate = translate_sty in @@ -256,6 +302,8 @@ let rec translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) (** Small utility: list the transitive parents of a region var group. We don't do that in an efficient manner, but it doesn't matter. + + TODO: remove? *) let rec list_parent_region_groups (def : A.fun_def) (gid : T.RegionGroupId.id) : T.RegionGroupId.Set.t = @@ -273,7 +321,10 @@ let rec list_parent_region_groups (def : A.fun_def) (gid : T.RegionGroupId.id) : in parents -(** Small utility: same as [list_parent_region_groups], but returns an ordered list *) +(** Small utility: same as [list_parent_region_groups], but returns an ordered list. + + TODO: remove? + *) let list_ordered_parent_region_groups (def : A.fun_def) (gid : T.RegionGroupId.id) : T.RegionGroupId.id list = let pset = list_parent_region_groups def gid in @@ -285,6 +336,30 @@ let list_ordered_parent_region_groups (def : A.fun_def) let parents = List.map (fun (rg : T.region_var_group) -> rg.id) parents in parents +let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) : + V.AbstractionId.id list = + (* We could do something more "elegant" without references, but it is + * so much simpler to use references... *) + let abs_set = ref V.AbstractionId.Set.empty in + let rec gather (abs_id : V.AbstractionId.id) : unit = + if V.AbstractionId.Set.mem abs_id !abs_set then () + else ( + abs_set := V.AbstractionId.Set.add abs_id !abs_set; + let abs = V.AbstractionId.Map.find abs_id ctx.abstractions in + List.iter gather abs.original_parents) + in + gather abs.abs_id; + let ids = !abs_set in + (* List the ancestors, in the proper order *) + let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in + List.filter + (fun id -> V.AbstractionId.Set.mem id ids) + call_info.forward.abstractions + +let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : V.abs list = + let abs_ids = list_ancestor_abstractions_ids ctx abs in + List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids + let translate_fun_sig (ctx : bs_ctx) (def : A.fun_def) (bid : T.RegionGroupId.id option) : fun_sig = let sg = def.signature in @@ -390,31 +465,38 @@ let get_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = let mk_place_from_var (v : var) : place = { var = v.id; projection = [] } (* TODO: move *) -let type_def_is_enum (def : type_def) : bool = - match def.kind with Struct _ -> false | Enum _ -> true +let type_def_is_enum (def : T.type_def) : bool = + match def.kind with T.Struct _ -> false | Enum _ -> true let typed_avalue_to_consumed (ctx : bs_ctx) (av : V.typed_avalue) : typed_value option = raise Unimplemented -let typed_avalue_to_given_back (ctx : bs_ctx) (av : V.typed_avalue) : - typed_value option = - raise Unimplemented - let abs_to_consumed (ctx : bs_ctx) (abs : V.abs) : typed_value list = List.filter_map (typed_avalue_to_consumed ctx) abs.avalues -let abs_to_given_back (ctx : bs_ctx) (abs : V.abs) : typed_value list = - List.filter_map (typed_avalue_to_given_back ctx) abs.avalues +let typed_avalue_to_given_back (av : V.typed_avalue) (ctx : bs_ctx) : + bs_ctx * lvalue option = + raise Unimplemented + +let abs_to_given_back (abs : V.abs) (ctx : bs_ctx) : bs_ctx * lvalue list = + let ctx, values = + List.fold_left_map + (fun ctx av -> typed_avalue_to_given_back av ctx) + ctx abs.avalues + in + let values = List.filter_map (fun x -> x) values in + (ctx, values) (** Return the ordered list of the (transitive) parents of a given abstraction. Is used for instance when collecting the input values given to all the parent functions, in order to properly instantiate an *) -let get_abs_ordered_parents (ctx : bs_ctx) (call_id : S.call_id) - (gid : T.RegionGroupId.id) : S.call * V.abs list = - raise Unimplemented +let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list = + let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in + let abs_ancestors = list_ancestor_abstractions ctx abs in + (call_info.forward, abs_ancestors) let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression = match e with @@ -422,46 +504,71 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression = let v = translate_typed_value ctx v in Return (Value v) | Panic -> Panic - | FunCall (call, e) -> - (* Translate the function call *) - let type_params = List.map (translate_fwd_ty ctx) call.type_params in - let args = List.map (translate_typed_value ctx) call.args in - let args = List.map (fun v -> Value v) args in - let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in - let func = - match call.call_id with - | S.Fun (A.Local fid, _) -> Local (fid, None) - | S.Fun (A.Assumed fid, _) -> Assumed fid - | S.Unop unop -> Unop unop - | S.Binop binop -> Binop binop - in - let call = { func; type_params; args } in - (* Translate the next expression *) - let e = translate_expression e ctx in - (* Put together *) - Let (Call ([ Var dest ], call), e) - | EndAbstraction (abs, e) -> translate_end_abstraction abs e + | FunCall (call, e) -> translate_function_call call e ctx + | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx | Expansion (sv, exp) -> translate_expansion sv exp ctx | Meta (_, e) -> (* We ignore the meta information *) translate_expression e ctx -and translate_end_abstraction (abs : V.abs) (e : S.expression) : expression = +and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : + expression = + (* Translate the function call *) + let type_params = List.map (translate_fwd_ty ctx) call.type_params in + let args = List.map (translate_typed_value ctx) call.args in + let args = List.map (fun v -> Value v) args in + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + (* Retrieve the function id, and register the function call in the context + * if necessary *) + let ctx, func = + match call.call_id with + | S.Fun (fid, call_id) -> + let ctx = bs_ctx_register_forward_call call_id call ctx in + let func = + match fid with + | A.Local fid -> Local (fid, None) + | A.Assumed fid -> Assumed fid + in + (ctx, func) + | S.Unop unop -> (ctx, Unop unop) + | S.Binop binop -> (ctx, Binop binop) + in + let call = { func; type_params; args } in + (* Translate the next expression *) + let e = translate_expression e ctx in + (* Put together *) + Let (Call ([ Var dest ], call), e) + +and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : + expression = match abs.kind with | V.SynthInput -> (* There are no nested borrows for now: we shouldn't get there *) raise Unimplemented | V.FunCall -> (* Retrive the orignal call and the parent abstractions *) + let forward, backwards = get_abs_ancestors ctx abs in (* Retrieve the values consumed when we called the forward function and * ended the parent backward functions: those give us part of the input * values *) + let fwd_inputs = List.map (translate_typed_value ctx) forward.args in + let back_ancestors_inputs = + List.concat (List.map abs_to_consumed backwards) + in (* Retrieve the values consumed upon ending the loans inside this * abstraction: those give us the remaining input values *) + let back_inputs = abs_to_consumed abs in + let inputs = + List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs ] + in (* Retrieve the values given back by this function: those are the output * values *) + let ctx, outputs = abs_to_given_back abs ctx in + (* Translate the next expression *) + let e = translate_expression e ctx in (* Put everything together *) - raise Unimplemented + let call = { func; type_params; args = inputs } in + Let (Call (outputs, call), e) | V.SynthRet -> (* *) raise Unimplemented @@ -499,7 +606,7 @@ and translate_expansion (sv : V.symbolic_value) (exp : S.expansion) match type_id with | T.AdtId adt_id -> (* Detect if this is an enumeration or not *) - let tdef = bs_ctx_lookup_type_def adt_id ctx in + let tdef = bs_ctx_lookup_cfim_type_def adt_id ctx in let is_enum = type_def_is_enum tdef in if is_enum then (* This is an enumeration: introduce an [ExpandEnum] let-binding *) @@ -568,9 +675,14 @@ and translate_expansion (sv : V.symbolic_value) (exp : S.expansion) let otherwise = translate_expression otherwise ctx in Switch (scrutinee, SwitchInt (int_ty, branches, otherwise)) -let translate_fun_def (type_context : type_context) (def : A.fun_def) - (bid : T.RegionGroupId.id option) (body : S.expression) : fun_def = - let bs_ctx = { type_context; fun_def = def; bid } in +let translate_fun_def (type_context : type_context) (fun_context : fun_context) + (def : A.fun_def) (bid : T.RegionGroupId.id option) (body : S.expression) : + fun_def = + let calls = V.FunCallId.Map.empty in + let abstractions = V.AbstractionId.Map.empty in + let bs_ctx = + { type_context; fun_context; fun_def = def; bid; calls; abstractions } + in (* Translate the function *) let def_id = def.A.def_id in let name = translate_fun_name def bid in |