From 5095a482e9a208db239b909fae1e9c7fea4f5117 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 25 Jan 2022 12:10:38 +0100 Subject: Make good progress on SymbolicToPure.translate_expansion --- src/Pure.ml | 26 +++++-- src/SymbolicAst.ml | 13 ++-- src/SymbolicToPure.ml | 172 ++++++++++++++++++++++++++++++++++++++++------ src/SynthesizeSymbolic.ml | 39 +++++------ src/TypesUtils.ml | 16 +++++ 5 files changed, 210 insertions(+), 56 deletions(-) diff --git a/src/Pure.ml b/src/Pure.ml index 6549d3fa..a0c8d70c 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -66,6 +66,7 @@ type symbolic_value = { symbolic value was introduced. *) } +(** TODO: remove? *) type value = Concrete of constant_value | Adt of adt_value @@ -114,11 +115,22 @@ type left_value = unit type let_bindings = | Call of var_or_dummy list * call (** The called function and the tuple of returned values *) - | Assignment of var * place + | Assignment of var * operand (** Variable assignment: the introduced variable and the place we read *) - | ExpandEnum of var_or_dummy list * TypeDefId.id * VariantId.id * place - (** When expanding an enumeration with exactly one variant, we first - introduce something like this (with [ExpandEnum]): + | Deconstruct of + var_or_dummy list * (TypeDefId.id * VariantId.id) option * operand + (** This is used in two cases. + + 1. When deconstructing a tuple: + ``` + let (x, y) = p in ... + ``` + (not all languages have syntax like `p.0`, `p.1`... and it is more + readable anyway). + + 2. When expanding an enumeration with one variant. + + In this case, [Deconstruct] has to be understood as: ``` let Cons x tl = ls in ... @@ -134,8 +146,8 @@ type let_bindings = Note that we prefer not handling this case through a match. TODO: actually why not encoding it as a match internally, then - generating the `let Cons ... = ... in ...` if we check there - is exactly one variant?... + generating the `let Cons ... = ... in ...` upon outputting the + code if we detect there is exactly one variant?... *) (** **Rk.:** here, [expression] is not at all equivalent to the expressions @@ -156,7 +168,7 @@ and switch_body = and match_branch = { variant_id : VariantId.id; - vars : symbolic_value list; + vars : var_or_dummy list; branch : expression; } diff --git a/src/SymbolicAst.ml b/src/SymbolicAst.ml index f1939802..45cdc4b2 100644 --- a/src/SymbolicAst.ml +++ b/src/SymbolicAst.ml @@ -40,12 +40,15 @@ type expression = and expansion = | ExpandNoBranch of V.symbolic_expansion * expression (** A symbolic expansion which doesn't generate a branching. - Includes: expansion of borrows, structures, enumerations with - exactly one variant... *) - | ExpandEnum of + Includes: + - concrete expansion + - borrow expansion + *Doesn't* include: + - expansion of ADTs with one variant + *) + | ExpandAdt of (T.VariantId.id option * V.symbolic_value list * expression) list - (** A symbolic expansion of an ADT value which leads to branching (i.e., - a match over an enumeration with strictly more than one variant *) + (** ADT expansion *) | ExpandBool of expression * expression (** A boolean expansion (i.e, an `if ... then ... else ...`) *) | ExpandInt of diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 69c22b09..469021a0 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -97,7 +97,10 @@ let translate_fun_name (fdef : A.fun_def) (bid : V.BackwardFunctionId.id option) (** Generates a name for a type (simply reuses the name in the definition) *) let translate_type_name (def : T.type_def) : Id.name = def.T.name -type type_context = { type_defs : type_def TypeDefId.Map.t } +type type_context = { + types_infos : TA.type_infos; + type_defs : type_def TypeDefId.Map.t; +} type fun_context = { fun_defs : fun_def FunDefId.Map.t } @@ -110,9 +113,16 @@ type synth_ctx = { declarations : M.declaration_group list; } -type bs_ctx = { types_infos : TA.type_infos } +type bs_ctx = { + type_context : type_context; + fun_def : A.fun_def; + bid : V.BackwardFunctionId.id option; +} (** Body synthesis context *) +let bs_ctx_lookup_type_def (id : TypeDefId.id) (ctx : bs_ctx) : type_def = + TypeDefId.Map.find id ctx.type_context.type_defs + let rec translate_sty (ty : T.sty) : ty = let translate = translate_sty in match ty with @@ -174,12 +184,13 @@ let translate_type_def (def : T.type_def) : type_def = let rec translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty = let translate = translate_fwd_ty ctx in + let types_infos = ctx.type_context.types_infos in match ty with | T.Adt (type_id, regions, tys) -> (* Can't translate types with regions for now *) assert (regions = []); (* No general parametricity for now *) - assert (not (List.exists (TypesUtils.ty_has_borrows ctx.types_infos) tys)); + assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys)); (* Translate the type parameters *) let tys = List.map translate tys in Adt (type_id, tys) @@ -190,10 +201,10 @@ let rec translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty = | Integer int_ty -> Integer int_ty | Str -> Str | Array ty -> - assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty)); + assert (not (TypesUtils.ty_has_borrows types_infos ty)); Array (translate ty) | Slice ty -> - assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty)); + assert (not (TypesUtils.ty_has_borrows types_infos ty)); Slice (translate ty) | Ref (_, rty, _) -> translate rty @@ -206,6 +217,7 @@ let rec translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty = let rec translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) (inside_mut : bool) (ty : 'r T.ty) : ty option = let translate = translate_back_ty ctx keep_region inside_mut in + let types_infos = ctx.type_context.types_infos in (* A small helper for "leave" types *) let wrap ty = if inside_mut then Some ty else None in match ty with @@ -213,7 +225,7 @@ let rec translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) match type_id with | T.AdtId _ | Assumed _ -> (* Don't accept ADTs (which are not tuples) with borrows for now *) - assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty)); + assert (not (TypesUtils.ty_has_borrows types_infos ty)); None | T.Tuple -> ( (* Tuples can contain borrows (which we eliminated) *) @@ -226,10 +238,10 @@ let rec translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) | Integer int_ty -> wrap (Integer int_ty) | Str -> wrap Str | Array ty -> ( - assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty)); + assert (not (TypesUtils.ty_has_borrows types_infos ty)); match translate ty with None -> None | Some ty -> Some (Array ty)) | Slice ty -> ( - assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty)); + assert (not (TypesUtils.ty_has_borrows types_infos ty)); match translate ty with None -> None | Some ty -> Some (Slice ty)) | Ref (r, rty, rkind) -> ( match rkind with @@ -336,31 +348,149 @@ let translate_fun_sig (ctx : bs_ctx) (def : A.fun_def) (* Return *) { type_params; inputs; outputs } -let translate_typed_value (v : V.typed_value) (ctx : bs_ctx) : - bs_ctx * typed_value = +let translate_typed_value (ctx : bs_ctx) (v : V.typed_value) : typed_value = + raise Unimplemented + +let fresh_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : + bs_ctx * var = raise Unimplemented -let rec translate_expression (def : A.fun_def) - (bid : V.BackwardFunctionId.id option) (body : S.expression) (ctx : bs_ctx) - : expression = - match body with +let fresh_vars_for_symbolic_values (svl : V.symbolic_value list) (ctx : bs_ctx) + : bs_ctx * var list = + List.fold_left_map (fun ctx sv -> fresh_var_for_symbolic_value sv ctx) ctx svl + +let get_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = + raise Unimplemented + +(* TODO: move *) +let mk_place_from_var (v : var) : place = { var = v.id; projection = [] } + +(* TODO: move *) +let type_def_is_enum (def : type_def) : bool = + match def.kind with Struct _ -> false | Enum _ -> true + +let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression = + match e with | S.Return v -> - let _, v = translate_typed_value v ctx in + let v = translate_typed_value ctx v in Return (Value v) | Panic -> Panic - | FunCall (call, e) -> raise Unimplemented + | FunCall (call, e) -> + (* Translate the function call *) + let type_params = List.map (translate_fwd_ty ctx) call.type_params in + let args = List.map (translate_typed_value ctx) call.args in + let args = List.map (fun v -> Value v) args in + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let func = + match call.call_id with + | S.Fun (A.Local fid, _) -> Local (fid, None) + | S.Fun (A.Assumed fid, _) -> Assumed fid + | S.Unop unop -> Unop unop + | S.Binop binop -> Binop binop + in + let call = { func; type_params; args } in + (* Translate the next expression *) + let e = translate_expression e ctx in + (* Put together *) + Let (Call ([ Var dest ], call), e) | EndAbstraction (abs, e) -> raise Unimplemented - | Expansion (sv, exp) -> raise Unimplemented + | Expansion (sv, exp) -> translate_expansion sv exp ctx | Meta (_, e) -> (* We ignore the meta information *) - translate_expression def bid e ctx + translate_expression e ctx + +and translate_expansion (sv : V.symbolic_value) (exp : S.expansion) + (ctx : bs_ctx) : expression = + (* Translate the scrutinee *) + let scrutinee_var = get_var_for_symbolic_value sv ctx in + let scrutinee = Place (mk_place_from_var scrutinee_var) in + (* Translate the branches *) + match exp with + | ExpandNoBranch (_, _) -> raise Unimplemented + | ExpandAdt branches -> ( + (* We don't do the same thing if there is a branching or not *) + match branches with + | [] -> failwith "Unreachable" + | [ (variant_id, svl, branch) ] -> ( + let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in + let ctx, vars = fresh_vars_for_symbolic_values svl ctx in + let branch = translate_expression branch ctx in + match type_id with + | T.AdtId adt_id -> + (* Detect if this is an enumeration or not *) + let tdef = bs_ctx_lookup_type_def adt_id ctx in + let is_enum = type_def_is_enum tdef in + if is_enum then + (* This is an enumeration: introduce an [ExpandEnum] let-binding *) + let variant_id = Option.get variant_id in + let vars = List.map (fun x -> Var x) vars in + Let + ( Deconstruct (vars, Some (adt_id, variant_id), scrutinee), + branch ) + else + (* This is not an enumeration: introduce let-bindings for every + * field *) + let gen_field_proj (field_id : FieldId.id) : operand = + let pkind = E.ProjAdt (adt_id, None) in + let pe : projection_elem = { pkind; field_id } in + let projection = [ pe ] in + let place = { var = scrutinee_var.id; projection } in + Place place + in + let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in + List.fold_right + (fun (fid, var) e -> + let field_proj = gen_field_proj fid in + Let (Assignment (var, field_proj), e)) + id_var_pairs branch + | T.Tuple -> raise Unimplemented + | T.Assumed T.Box -> + (* There should be exactly one variable *) + let var = + match vars with [ v ] -> v | _ -> failwith "Unreachable" + in + (* We simply introduce an assignment - the box type is the + * identity when extracted (`box a == a`) *) + Let (Assignment (var, scrutinee), branch)) + | branches -> + let translate_branch (variant_id : T.VariantId.id option) + (svl : V.symbolic_value list) (branch : S.expression) : + match_branch = + (* There *must* be a variant id - otherwise there can't be several branches *) + let variant_id = Option.get variant_id in + let ctx, vars = fresh_vars_for_symbolic_values svl ctx in + let vars = List.map (fun x -> Var x) vars in + let branch = translate_expression branch ctx in + { variant_id; vars; branch } + in + let branches = + List.map (fun (vid, svl, e) -> translate_branch vid svl e) branches + in + Switch (scrutinee, Match branches)) + | ExpandBool (true_e, false_e) -> + (* We don't need to update the context: we don't introduce any + * new values/variables *) + let true_e = translate_expression true_e ctx in + let false_e = translate_expression false_e ctx in + Switch (scrutinee, If (true_e, false_e)) + | ExpandInt (int_ty, branches, otherwise) -> + let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : + scalar_value * expression = + (* We don't need to update the context: we don't introduce any + * new values/variables *) + let branch_e = translate_expression branch_e ctx in + (v, branch_e) + in + let branches = List.map translate_branch branches in + let otherwise = translate_expression otherwise ctx in + Switch (scrutinee, SwitchInt (int_ty, branches, otherwise)) -let translate_fun_def (types_infos : TA.type_infos) (def : A.fun_def) +let translate_fun_def (type_context : type_context) (def : A.fun_def) (bid : V.BackwardFunctionId.id option) (body : S.expression) : fun_def = - let bs_ctx = { types_infos } in + let bs_ctx = { type_context; fun_def = def; bid } in (* Translate the function *) let def_id = def.A.def_id in let name = translate_fun_name def bid in let signature = translate_fun_sig bs_ctx def bid in - let body = translate_expression def bid body bs_ctx in + let body = translate_expression body bs_ctx in { def_id; name; signature; body } diff --git a/src/SynthesizeSymbolic.ml b/src/SynthesizeSymbolic.ml index 1a30687c..c0d4489c 100644 --- a/src/SynthesizeSymbolic.ml +++ b/src/SynthesizeSymbolic.ml @@ -49,29 +49,22 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) assert (otherwise_see = None); (* Return *) ExpandInt (int_ty, branches, otherwise) - | T.Adt (_, _, _) -> ( - (* An ADT expansion can lead to branching: check if this is the case *) - match ls with - | [] -> failwith "Ill-formed ADT expansion" - | [ (see, exp) ] -> - (* No branching *) - ExpandNoBranch (Option.get see, exp) - | ls -> - (* Branching: it is necessarily an enumeration expansion *) - let get_variant (see : V.symbolic_expansion option) : - T.VariantId.id option * V.symbolic_value list = - match see with - | Some (V.SeAdt (vid, fields)) -> (vid, fields) - | _ -> failwith "Ill-formed branching ADT expansion" - in - let exp = - List.map - (fun (see, exp) -> - let vid, fields = get_variant see in - (vid, fields, exp)) - ls - in - ExpandEnum exp) + | T.Adt (_, _, _) -> + (* Branching: it is necessarily an enumeration expansion *) + let get_variant (see : V.symbolic_expansion option) : + T.VariantId.id option * V.symbolic_value list = + match see with + | Some (V.SeAdt (vid, fields)) -> (vid, fields) + | _ -> failwith "Ill-formed branching ADT expansion" + in + let exp = + List.map + (fun (see, exp) -> + let vid, fields = get_variant see in + (vid, fields, exp)) + ls + in + ExpandAdt exp | T.Ref (_, _, _) -> ( (* Reference expansion: there should be one branch *) match ls with diff --git a/src/TypesUtils.ml b/src/TypesUtils.ml index 86076469..1eac5cee 100644 --- a/src/TypesUtils.ml +++ b/src/TypesUtils.ml @@ -21,6 +21,22 @@ let type_def_get_fields (def : type_def) (opt_variant_id : VariantId.id option) let ty_is_unit (ty : 'r ty) : bool = match ty with Adt (Tuple, [], []) -> true | _ -> false +let ty_is_adt (ty : 'r ty) : bool = + match ty with Adt (_, _, _) -> true | _ -> false + +let ty_as_adt (ty : 'r ty) : type_id * 'r list * 'r ty list = + match ty with + | Adt (id, regions, tys) -> (id, regions, tys) + | _ -> failwith "Unreachable" + +let ty_is_custom_adt (ty : 'r ty) : bool = + match ty with Adt (AdtId _, _, _) -> true | _ -> false + +let ty_as_custom_adt (ty : 'r ty) : TypeDefId.id * 'r list * 'r ty list = + match ty with + | Adt (AdtId id, regions, tys) -> (id, regions, tys) + | _ -> failwith "Unreachable" + (** The unit type *) let mk_unit_ty : ety = Adt (Tuple, [], []) -- cgit v1.2.3