summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-01-25 12:10:38 +0100
committerSon Ho2022-01-25 12:10:38 +0100
commit5095a482e9a208db239b909fae1e9c7fea4f5117 (patch)
treebb3d2f0d5b9e143a2614f4724786e9baf9b52af9 /src
parent7870b9f816b095164d89a7ea07a9bc29bf8af875 (diff)
Make good progress on SymbolicToPure.translate_expansion
Diffstat (limited to 'src')
-rw-r--r--src/Pure.ml26
-rw-r--r--src/SymbolicAst.ml13
-rw-r--r--src/SymbolicToPure.ml172
-rw-r--r--src/SynthesizeSymbolic.ml39
-rw-r--r--src/TypesUtils.ml16
5 files changed, 210 insertions, 56 deletions
diff --git a/src/Pure.ml b/src/Pure.ml
index 6549d3fa..a0c8d70c 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -66,6 +66,7 @@ type symbolic_value = {
symbolic value was introduced.
*)
}
+(** TODO: remove? *)
type value = Concrete of constant_value | Adt of adt_value
@@ -114,11 +115,22 @@ type left_value = unit
type let_bindings =
| Call of var_or_dummy list * call
(** The called function and the tuple of returned values *)
- | Assignment of var * place
+ | Assignment of var * operand
(** Variable assignment: the introduced variable and the place we read *)
- | ExpandEnum of var_or_dummy list * TypeDefId.id * VariantId.id * place
- (** When expanding an enumeration with exactly one variant, we first
- introduce something like this (with [ExpandEnum]):
+ | Deconstruct of
+ var_or_dummy list * (TypeDefId.id * VariantId.id) option * operand
+ (** This is used in two cases.
+
+ 1. When deconstructing a tuple:
+ ```
+ let (x, y) = p in ...
+ ```
+ (not all languages have syntax like `p.0`, `p.1`... and it is more
+ readable anyway).
+
+ 2. When expanding an enumeration with one variant.
+
+ In this case, [Deconstruct] has to be understood as:
```
let Cons x tl = ls in
...
@@ -134,8 +146,8 @@ type let_bindings =
Note that we prefer not handling this case through a match.
TODO: actually why not encoding it as a match internally, then
- generating the `let Cons ... = ... in ...` if we check there
- is exactly one variant?...
+ generating the `let Cons ... = ... in ...` upon outputting the
+ code if we detect there is exactly one variant?...
*)
(** **Rk.:** here, [expression] is not at all equivalent to the expressions
@@ -156,7 +168,7 @@ and switch_body =
and match_branch = {
variant_id : VariantId.id;
- vars : symbolic_value list;
+ vars : var_or_dummy list;
branch : expression;
}
diff --git a/src/SymbolicAst.ml b/src/SymbolicAst.ml
index f1939802..45cdc4b2 100644
--- a/src/SymbolicAst.ml
+++ b/src/SymbolicAst.ml
@@ -40,12 +40,15 @@ type expression =
and expansion =
| ExpandNoBranch of V.symbolic_expansion * expression
(** A symbolic expansion which doesn't generate a branching.
- Includes: expansion of borrows, structures, enumerations with
- exactly one variant... *)
- | ExpandEnum of
+ Includes:
+ - concrete expansion
+ - borrow expansion
+ *Doesn't* include:
+ - expansion of ADTs with one variant
+ *)
+ | ExpandAdt of
(T.VariantId.id option * V.symbolic_value list * expression) list
- (** A symbolic expansion of an ADT value which leads to branching (i.e.,
- a match over an enumeration with strictly more than one variant *)
+ (** ADT expansion *)
| ExpandBool of expression * expression
(** A boolean expansion (i.e, an `if ... then ... else ...`) *)
| ExpandInt of
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 69c22b09..469021a0 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -97,7 +97,10 @@ let translate_fun_name (fdef : A.fun_def) (bid : V.BackwardFunctionId.id option)
(** Generates a name for a type (simply reuses the name in the definition) *)
let translate_type_name (def : T.type_def) : Id.name = def.T.name
-type type_context = { type_defs : type_def TypeDefId.Map.t }
+type type_context = {
+ types_infos : TA.type_infos;
+ type_defs : type_def TypeDefId.Map.t;
+}
type fun_context = { fun_defs : fun_def FunDefId.Map.t }
@@ -110,9 +113,16 @@ type synth_ctx = {
declarations : M.declaration_group list;
}
-type bs_ctx = { types_infos : TA.type_infos }
+type bs_ctx = {
+ type_context : type_context;
+ fun_def : A.fun_def;
+ bid : V.BackwardFunctionId.id option;
+}
(** Body synthesis context *)
+let bs_ctx_lookup_type_def (id : TypeDefId.id) (ctx : bs_ctx) : type_def =
+ TypeDefId.Map.find id ctx.type_context.type_defs
+
let rec translate_sty (ty : T.sty) : ty =
let translate = translate_sty in
match ty with
@@ -174,12 +184,13 @@ let translate_type_def (def : T.type_def) : type_def =
let rec translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
let translate = translate_fwd_ty ctx in
+ let types_infos = ctx.type_context.types_infos in
match ty with
| T.Adt (type_id, regions, tys) ->
(* Can't translate types with regions for now *)
assert (regions = []);
(* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows ctx.types_infos) tys));
+ assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys));
(* Translate the type parameters *)
let tys = List.map translate tys in
Adt (type_id, tys)
@@ -190,10 +201,10 @@ let rec translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
| Integer int_ty -> Integer int_ty
| Str -> Str
| Array ty ->
- assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows types_infos ty));
Array (translate ty)
| Slice ty ->
- assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows types_infos ty));
Slice (translate ty)
| Ref (_, rty, _) -> translate rty
@@ -206,6 +217,7 @@ let rec translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
let rec translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
(inside_mut : bool) (ty : 'r T.ty) : ty option =
let translate = translate_back_ty ctx keep_region inside_mut in
+ let types_infos = ctx.type_context.types_infos in
(* A small helper for "leave" types *)
let wrap ty = if inside_mut then Some ty else None in
match ty with
@@ -213,7 +225,7 @@ let rec translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
match type_id with
| T.AdtId _ | Assumed _ ->
(* Don't accept ADTs (which are not tuples) with borrows for now *)
- assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows types_infos ty));
None
| T.Tuple -> (
(* Tuples can contain borrows (which we eliminated) *)
@@ -226,10 +238,10 @@ let rec translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
| Integer int_ty -> wrap (Integer int_ty)
| Str -> wrap Str
| Array ty -> (
- assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows types_infos ty));
match translate ty with None -> None | Some ty -> Some (Array ty))
| Slice ty -> (
- assert (not (TypesUtils.ty_has_borrows ctx.types_infos ty));
+ assert (not (TypesUtils.ty_has_borrows types_infos ty));
match translate ty with None -> None | Some ty -> Some (Slice ty))
| Ref (r, rty, rkind) -> (
match rkind with
@@ -336,31 +348,149 @@ let translate_fun_sig (ctx : bs_ctx) (def : A.fun_def)
(* Return *)
{ type_params; inputs; outputs }
-let translate_typed_value (v : V.typed_value) (ctx : bs_ctx) :
- bs_ctx * typed_value =
+let translate_typed_value (ctx : bs_ctx) (v : V.typed_value) : typed_value =
+ raise Unimplemented
+
+let fresh_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) :
+ bs_ctx * var =
raise Unimplemented
-let rec translate_expression (def : A.fun_def)
- (bid : V.BackwardFunctionId.id option) (body : S.expression) (ctx : bs_ctx)
- : expression =
- match body with
+let fresh_vars_for_symbolic_values (svl : V.symbolic_value list) (ctx : bs_ctx)
+ : bs_ctx * var list =
+ List.fold_left_map (fun ctx sv -> fresh_var_for_symbolic_value sv ctx) ctx svl
+
+let get_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
+ raise Unimplemented
+
+(* TODO: move *)
+let mk_place_from_var (v : var) : place = { var = v.id; projection = [] }
+
+(* TODO: move *)
+let type_def_is_enum (def : type_def) : bool =
+ match def.kind with Struct _ -> false | Enum _ -> true
+
+let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression =
+ match e with
| S.Return v ->
- let _, v = translate_typed_value v ctx in
+ let v = translate_typed_value ctx v in
Return (Value v)
| Panic -> Panic
- | FunCall (call, e) -> raise Unimplemented
+ | FunCall (call, e) ->
+ (* Translate the function call *)
+ let type_params = List.map (translate_fwd_ty ctx) call.type_params in
+ let args = List.map (translate_typed_value ctx) call.args in
+ let args = List.map (fun v -> Value v) args in
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let func =
+ match call.call_id with
+ | S.Fun (A.Local fid, _) -> Local (fid, None)
+ | S.Fun (A.Assumed fid, _) -> Assumed fid
+ | S.Unop unop -> Unop unop
+ | S.Binop binop -> Binop binop
+ in
+ let call = { func; type_params; args } in
+ (* Translate the next expression *)
+ let e = translate_expression e ctx in
+ (* Put together *)
+ Let (Call ([ Var dest ], call), e)
| EndAbstraction (abs, e) -> raise Unimplemented
- | Expansion (sv, exp) -> raise Unimplemented
+ | Expansion (sv, exp) -> translate_expansion sv exp ctx
| Meta (_, e) ->
(* We ignore the meta information *)
- translate_expression def bid e ctx
+ translate_expression e ctx
+
+and translate_expansion (sv : V.symbolic_value) (exp : S.expansion)
+ (ctx : bs_ctx) : expression =
+ (* Translate the scrutinee *)
+ let scrutinee_var = get_var_for_symbolic_value sv ctx in
+ let scrutinee = Place (mk_place_from_var scrutinee_var) in
+ (* Translate the branches *)
+ match exp with
+ | ExpandNoBranch (_, _) -> raise Unimplemented
+ | ExpandAdt branches -> (
+ (* We don't do the same thing if there is a branching or not *)
+ match branches with
+ | [] -> failwith "Unreachable"
+ | [ (variant_id, svl, branch) ] -> (
+ let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
+ let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
+ let branch = translate_expression branch ctx in
+ match type_id with
+ | T.AdtId adt_id ->
+ (* Detect if this is an enumeration or not *)
+ let tdef = bs_ctx_lookup_type_def adt_id ctx in
+ let is_enum = type_def_is_enum tdef in
+ if is_enum then
+ (* This is an enumeration: introduce an [ExpandEnum] let-binding *)
+ let variant_id = Option.get variant_id in
+ let vars = List.map (fun x -> Var x) vars in
+ Let
+ ( Deconstruct (vars, Some (adt_id, variant_id), scrutinee),
+ branch )
+ else
+ (* This is not an enumeration: introduce let-bindings for every
+ * field *)
+ let gen_field_proj (field_id : FieldId.id) : operand =
+ let pkind = E.ProjAdt (adt_id, None) in
+ let pe : projection_elem = { pkind; field_id } in
+ let projection = [ pe ] in
+ let place = { var = scrutinee_var.id; projection } in
+ Place place
+ in
+ let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in
+ List.fold_right
+ (fun (fid, var) e ->
+ let field_proj = gen_field_proj fid in
+ Let (Assignment (var, field_proj), e))
+ id_var_pairs branch
+ | T.Tuple -> raise Unimplemented
+ | T.Assumed T.Box ->
+ (* There should be exactly one variable *)
+ let var =
+ match vars with [ v ] -> v | _ -> failwith "Unreachable"
+ in
+ (* We simply introduce an assignment - the box type is the
+ * identity when extracted (`box a == a`) *)
+ Let (Assignment (var, scrutinee), branch))
+ | branches ->
+ let translate_branch (variant_id : T.VariantId.id option)
+ (svl : V.symbolic_value list) (branch : S.expression) :
+ match_branch =
+ (* There *must* be a variant id - otherwise there can't be several branches *)
+ let variant_id = Option.get variant_id in
+ let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
+ let vars = List.map (fun x -> Var x) vars in
+ let branch = translate_expression branch ctx in
+ { variant_id; vars; branch }
+ in
+ let branches =
+ List.map (fun (vid, svl, e) -> translate_branch vid svl e) branches
+ in
+ Switch (scrutinee, Match branches))
+ | ExpandBool (true_e, false_e) ->
+ (* We don't need to update the context: we don't introduce any
+ * new values/variables *)
+ let true_e = translate_expression true_e ctx in
+ let false_e = translate_expression false_e ctx in
+ Switch (scrutinee, If (true_e, false_e))
+ | ExpandInt (int_ty, branches, otherwise) ->
+ let translate_branch ((v, branch_e) : V.scalar_value * S.expression) :
+ scalar_value * expression =
+ (* We don't need to update the context: we don't introduce any
+ * new values/variables *)
+ let branch_e = translate_expression branch_e ctx in
+ (v, branch_e)
+ in
+ let branches = List.map translate_branch branches in
+ let otherwise = translate_expression otherwise ctx in
+ Switch (scrutinee, SwitchInt (int_ty, branches, otherwise))
-let translate_fun_def (types_infos : TA.type_infos) (def : A.fun_def)
+let translate_fun_def (type_context : type_context) (def : A.fun_def)
(bid : V.BackwardFunctionId.id option) (body : S.expression) : fun_def =
- let bs_ctx = { types_infos } in
+ let bs_ctx = { type_context; fun_def = def; bid } in
(* Translate the function *)
let def_id = def.A.def_id in
let name = translate_fun_name def bid in
let signature = translate_fun_sig bs_ctx def bid in
- let body = translate_expression def bid body bs_ctx in
+ let body = translate_expression body bs_ctx in
{ def_id; name; signature; body }
diff --git a/src/SynthesizeSymbolic.ml b/src/SynthesizeSymbolic.ml
index 1a30687c..c0d4489c 100644
--- a/src/SynthesizeSymbolic.ml
+++ b/src/SynthesizeSymbolic.ml
@@ -49,29 +49,22 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
assert (otherwise_see = None);
(* Return *)
ExpandInt (int_ty, branches, otherwise)
- | T.Adt (_, _, _) -> (
- (* An ADT expansion can lead to branching: check if this is the case *)
- match ls with
- | [] -> failwith "Ill-formed ADT expansion"
- | [ (see, exp) ] ->
- (* No branching *)
- ExpandNoBranch (Option.get see, exp)
- | ls ->
- (* Branching: it is necessarily an enumeration expansion *)
- let get_variant (see : V.symbolic_expansion option) :
- T.VariantId.id option * V.symbolic_value list =
- match see with
- | Some (V.SeAdt (vid, fields)) -> (vid, fields)
- | _ -> failwith "Ill-formed branching ADT expansion"
- in
- let exp =
- List.map
- (fun (see, exp) ->
- let vid, fields = get_variant see in
- (vid, fields, exp))
- ls
- in
- ExpandEnum exp)
+ | T.Adt (_, _, _) ->
+ (* Branching: it is necessarily an enumeration expansion *)
+ let get_variant (see : V.symbolic_expansion option) :
+ T.VariantId.id option * V.symbolic_value list =
+ match see with
+ | Some (V.SeAdt (vid, fields)) -> (vid, fields)
+ | _ -> failwith "Ill-formed branching ADT expansion"
+ in
+ let exp =
+ List.map
+ (fun (see, exp) ->
+ let vid, fields = get_variant see in
+ (vid, fields, exp))
+ ls
+ in
+ ExpandAdt exp
| T.Ref (_, _, _) -> (
(* Reference expansion: there should be one branch *)
match ls with
diff --git a/src/TypesUtils.ml b/src/TypesUtils.ml
index 86076469..1eac5cee 100644
--- a/src/TypesUtils.ml
+++ b/src/TypesUtils.ml
@@ -21,6 +21,22 @@ let type_def_get_fields (def : type_def) (opt_variant_id : VariantId.id option)
let ty_is_unit (ty : 'r ty) : bool =
match ty with Adt (Tuple, [], []) -> true | _ -> false
+let ty_is_adt (ty : 'r ty) : bool =
+ match ty with Adt (_, _, _) -> true | _ -> false
+
+let ty_as_adt (ty : 'r ty) : type_id * 'r list * 'r ty list =
+ match ty with
+ | Adt (id, regions, tys) -> (id, regions, tys)
+ | _ -> failwith "Unreachable"
+
+let ty_is_custom_adt (ty : 'r ty) : bool =
+ match ty with Adt (AdtId _, _, _) -> true | _ -> false
+
+let ty_as_custom_adt (ty : 'r ty) : TypeDefId.id * 'r list * 'r ty list =
+ match ty with
+ | Adt (AdtId id, regions, tys) -> (id, regions, tys)
+ | _ -> failwith "Unreachable"
+
(** The unit type *)
let mk_unit_ty : ety = Adt (Tuple, [], [])