From 99c93401c85b61ac2d254216b0b34884f44b1eff Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 25 Nov 2021 16:29:01 +0100 Subject: Start working on set_discriminant and factorize a bit expand_bottom_value --- src/Interpreter.ml | 124 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 88 insertions(+), 36 deletions(-) (limited to 'src') diff --git a/src/Interpreter.ml b/src/Interpreter.ml index 2acabd4f..bc2209c3 100644 --- a/src/Interpreter.ml +++ b/src/Interpreter.ml @@ -1057,6 +1057,47 @@ let write_place_unwrap (config : C.config) (access : access_kind) (p : E.place) | Error _ -> failwith "Unreachable" | Ok env -> env +(** Compute an expanded ADT bottom value *) +let compute_expanded_bottom_adt_value (tyctx : T.type_def T.TypeDefId.vector) + (def_id : T.TypeDefId.id) (opt_variant_id : T.VariantId.id option) + (regions : T.erased_region list) (types : T.ety list) : V.typed_value = + (* Lookup the definition and check if it is an enumeration - it + should be an enumeration if and only if the projection element + is a field projection with *some* variant id. Retrieve the list + of fields at the same time. *) + let def = T.TypeDefId.nth tyctx def_id in + assert (List.length regions = T.RegionVarId.length def.T.region_params); + (* Compute the field types *) + let field_types = + Subst.type_def_get_instantiated_field_type def opt_variant_id types + in + (* Initialize the expanded value *) + let field_types = T.FieldId.vector_to_list field_types in + let fields = List.map (fun ty -> { V.value = Bottom; ty }) field_types in + let fields = T.FieldId.vector_of_list fields in + let av = + V.Adt + { + def_id; + variant_id = opt_variant_id; + regions; + types; + field_values = fields; + } + in + let ty = T.Adt (def_id, regions, types) in + { V.value = av; V.ty } + +(** Compute an expanded tuple bottom value *) +let compute_expanded_bottom_tuple_value (field_types : T.ety list) : + V.typed_value = + (* Generate the field values *) + let fields = List.map (fun ty -> { V.value = Bottom; ty }) field_types in + let fields = T.FieldId.vector_of_list fields in + let v = V.Tuple fields in + let ty = T.Tuple field_types in + { V.value = v; V.ty } + (** Auxiliary helper to expand [Bottom] values. During compilation, rustc desaggregates the ADT initializations. The @@ -1077,6 +1118,8 @@ let write_place_unwrap (config : C.config) (access : access_kind) (p : E.place) to, say, [Cons Bottom Bottom] (note that field projection contains information about which variant we should project to, which is why we *can* set the variant index when writing one of its fields). + + TODO: rename to express the fact that this function uses a projection *) let expand_bottom_value (config : C.config) (access : access_kind) (tyctx : T.type_def T.TypeDefId.vector) (p : E.place) (remaining_pes : int) @@ -1101,46 +1144,15 @@ let expand_bottom_value (config : C.config) (access : access_kind) match (pe, ty) with | ( Field (ProjAdt (def_id, opt_variant_id), _), T.Adt (def_id', regions, types) ) -> - (* Lookup the definition and check if it is an enumeration - it - should be an enumeration if and only if the projection element - is a field projection with *some* variant id. Retrieve the list - of fields at the same time. *) assert (def_id = def_id'); - let def = T.TypeDefId.nth tyctx def_id in - let fields = - match (def.kind, opt_variant_id) with - | Struct fields, None -> fields - | Enum variants, Some variant_id -> - (* Retrieve the proper variant *) - let variant = T.VariantId.nth variants variant_id in - variant.fields - | _ -> failwith "Unreachable" - in - (* Initialize the expanded value *) - let fields = T.FieldId.vector_to_list fields in - let fields = - List.map - (fun f -> { V.value = Bottom; ty = T.erase_regions f.T.field_ty }) - fields - in - let fields = T.FieldId.vector_of_list fields in - V.Adt - { - def_id; - variant_id = opt_variant_id; - regions; - types; - field_values = fields; - } + compute_expanded_bottom_adt_value tyctx def_id opt_variant_id regions + types | Field (ProjTuple arity, _), T.Tuple tys -> assert (arity = List.length tys); (* Generate the field values *) - let fields = List.map (fun ty -> { V.value = Bottom; ty }) tys in - let fields = T.FieldId.vector_of_list fields in - V.Tuple fields + compute_expanded_bottom_tuple_value tys | _ -> failwith "Unreachable" in - let nv = { V.value = nv; ty } in (* Update the environment by inserting the expanded value at the proper place *) match write_place config access p' nv env with | Ok env -> env @@ -1816,6 +1828,41 @@ let prepare_lplace (config : C.config) (p : E.place) (ctx : C.eval_ctx) : let ctx3 = { ctx with env = env3 } in (ctx3, v) +(** Read the value at a given place. + As long as it is a loan, end the loan *) +let rec end_loan_exactly_at_place (config : C.config) (access : access_kind) + (p : E.place) (ctx : C.eval_ctx) : C.eval_ctx = + let v = read_place_unwrap config access p ctx.env in + match v.V.value with + | V.Loan (V.SharedLoan (bids, _)) -> + let env1 = end_borrows config Outer bids ctx.env in + let ctx1 = { ctx with env = env1 } in + end_loan_exactly_at_place config access p ctx1 + | V.Loan (V.MutLoan bid) -> + let env1 = end_borrow config Outer bid ctx.env in + let ctx1 = { ctx with env = env1 } in + end_loan_exactly_at_place config access p ctx1 + | _ -> ctx + +let set_discriminant (config : C.config) (ctx : C.eval_ctx) (p : E.place) + (variant_id : T.VariantId.id) : + (C.eval_ctx * statement_eval_res) eval_result = + (* Access the value *) + let access = Write in + let env1 = update_env_along_read_place config access p ctx.env in + let ctx1 = { ctx with env = env1 } in + let ctx2 = end_loan_exactly_at_place config access p ctx1 in + let v = read_place_unwrap config access p ctx2.env in + (* Update the value *) + match v.V.value with + | Adt av -> raise Unimplemented + | Bottom -> raise Unimplemented + | Symbolic _ -> + assert (config.mode = SymbolicMode); + (* TODO *) raise Unimplemented + | Concrete _ | Tuple _ | Borrow _ | Loan _ | Assumed _ -> + failwith "Unexpected value" + let rec eval_statement (config : C.config) (ctx : C.eval_ctx) (st : A.statement) : (C.eval_ctx * statement_eval_res) eval_result = match st with @@ -1834,7 +1881,8 @@ let rec eval_statement (config : C.config) (ctx : C.eval_ctx) (st : A.statement) | A.FakeRead p -> let ctx1, _ = prepare_rplace config Read p ctx in Ok (ctx1, Unit) - | A.SetDiscriminant (p, variant_id) -> raise Unimplemented + | A.SetDiscriminant (p, variant_id) -> + set_discriminant config ctx p variant_id | A.Drop p -> let ctx1, v = prepare_lplace config p ctx in let nv = { v with value = V.Bottom } in @@ -1848,9 +1896,13 @@ let rec eval_statement (config : C.config) (ctx : C.eval_ctx) (st : A.statement) | Concrete (Bool b) -> if b = assertion.expected then Ok (ctx1, Unit) else Error Panic | _ -> failwith "Expected a boolean") - | A.Call call -> raise Unimplemented + | A.Call call -> eval_function_call config ctx call | A.Panic -> Error Panic | A.Return -> Ok (ctx, Return) | A.Break i -> Ok (ctx, Break i) | A.Continue i -> Ok (ctx, Continue i) | A.Nop -> Ok (ctx, Unit) + +and eval_function_call (config : C.config) (ctx : C.eval_ctx) (call : A.call) : + (C.eval_ctx * statement_eval_res) eval_result = + raise Unimplemented -- cgit v1.2.3