summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-01-28 10:26:59 +0100
committerSon Ho2022-01-28 10:26:59 +0100
commit7deb7a2bde6d6bcdf14aac4f68f336bc498b964b (patch)
tree844f41bb7a427b15b75cf5827bb4519b2930ae88
parent1153b33184118cd4ee8d4ebca6081183879c0b49 (diff)
Make substantial simplifications to the pure AST
Diffstat (limited to '')
-rw-r--r--src/PrintPure.ml154
-rw-r--r--src/Pure.ml154
-rw-r--r--src/PureMicroPasses.ml130
-rw-r--r--src/SymbolicToPure.ml69
4 files changed, 275 insertions, 232 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml
index 064d8b9d..3e68db90 100644
--- a/src/PrintPure.ml
+++ b/src/PrintPure.ml
@@ -195,18 +195,9 @@ let var_to_string (fmt : type_formatter) (v : var) : string =
let var_or_dummy_to_string (fmt : value_formatter) (v : var_or_dummy) : string =
match v with
- | Var v -> var_to_string (value_to_type_formatter fmt) v
+ | Var (v, _) -> var_to_string (value_to_type_formatter fmt) v
| Dummy -> "_"
-let rec typed_lvalue_to_string (fmt : value_formatter) (v : typed_lvalue) :
- string =
- match v.value with
- | LvVar var -> var_or_dummy_to_string fmt var
- | LvTuple values ->
- "("
- ^ String.concat ", " (List.map (typed_lvalue_to_string fmt) values)
- ^ ")"
-
let rec projection_to_string (fmt : ast_formatter) (inside : string)
(p : projection) : string =
match p with
@@ -230,48 +221,63 @@ let place_to_string (fmt : ast_formatter) (p : place) : string =
let var = fmt.var_id_to_string p.var in
projection_to_string fmt var p.projection
+let adt_g_value_to_string (fmt : value_formatter)
+ (value_to_string : 'v -> string) (variant_id : VariantId.id option)
+ (field_values : 'v list) (ty : ty) : string =
+ let field_values = List.map value_to_string field_values in
+ match ty with
+ | Adt (T.Tuple, _) ->
+ (* Tuple *)
+ "(" ^ String.concat ", " field_values ^ ")"
+ | Adt (T.AdtId def_id, _) ->
+ (* "Regular" ADT *)
+ let adt_ident =
+ match variant_id with
+ | Some vid -> fmt.adt_variant_to_string def_id vid
+ | None -> fmt.type_def_id_to_string def_id
+ in
+ if field_values <> [] then
+ match fmt.adt_field_names def_id variant_id with
+ | None ->
+ let field_values = String.concat ", " field_values in
+ adt_ident ^ " (" ^ field_values ^ ")"
+ | Some field_names ->
+ let field_values = List.combine field_names field_values in
+ let field_values =
+ List.map
+ (fun (field, value) -> field ^ " = " ^ value ^ ";")
+ field_values
+ in
+ let field_values = String.concat " " field_values in
+ adt_ident ^ " { " ^ field_values ^ " }"
+ else adt_ident
+ | Adt (T.Assumed aty, _) -> (
+ (* Assumed type *)
+ match aty with
+ | Box ->
+ (* Box values should have been eliminated *)
+ failwith "Unreachable")
+ | _ -> failwith "Inconsistent typed value"
+
+let rec typed_lvalue_to_string (fmt : value_formatter) (v : typed_lvalue) :
+ string =
+ match v.value with
+ | LvVar var -> var_or_dummy_to_string fmt var
+ | LvAdt av ->
+ adt_g_value_to_string fmt
+ (typed_lvalue_to_string fmt)
+ av.variant_id av.field_values v.ty
+
let rec typed_rvalue_to_string (fmt : ast_formatter) (v : typed_rvalue) : string
=
match v.value with
| RvConcrete cv -> Print.Values.constant_value_to_string cv
| RvPlace p -> place_to_string fmt p
- | RvAdt av -> (
- let field_values =
- List.map (typed_rvalue_to_string fmt) av.field_values
- in
- match v.ty with
- | Adt (T.Tuple, _) ->
- (* Tuple *)
- "(" ^ String.concat ", " field_values ^ ")"
- | Adt (T.AdtId def_id, _) ->
- (* "Regular" ADT *)
- let adt_ident =
- match av.variant_id with
- | Some vid -> fmt.adt_variant_to_string def_id vid
- | None -> fmt.type_def_id_to_string def_id
- in
- if field_values <> [] then
- match fmt.adt_field_names def_id av.variant_id with
- | None ->
- let field_values = String.concat ", " field_values in
- adt_ident ^ " (" ^ field_values ^ ")"
- | Some field_names ->
- let field_values = List.combine field_names field_values in
- let field_values =
- List.map
- (fun (field, value) -> field ^ " = " ^ value ^ ";")
- field_values
- in
- let field_values = String.concat " " field_values in
- adt_ident ^ " { " ^ field_values ^ " }"
- else adt_ident
- | Adt (T.Assumed aty, _) -> (
- (* Assumed type *)
- match aty with
- | Box ->
- (* Box values should have been eliminated *)
- failwith "Unreachable")
- | _ -> failwith "Inconsistent typed value")
+ | RvAdt av ->
+ adt_g_value_to_string
+ (ast_to_value_formatter fmt)
+ (typed_rvalue_to_string fmt)
+ av.variant_id av.field_values v.ty
let fun_sig_to_string (fmt : ast_formatter) (sg : fun_sig) : string =
let ty_fmt = ast_to_type_formatter fmt in
@@ -365,19 +371,8 @@ and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string)
let e = expression_to_string fmt indent indent_incr e in
let val_fmt = ast_to_value_formatter fmt in
match lb with
- | Call (lvs, call) ->
- let lvs =
- List.map (fun (lv, _) -> typed_lvalue_to_string val_fmt lv) lvs
- in
- let lvs =
- match lvs with
- | [] ->
- (* Can happen with backward functions which don't give back
- * anything (shared borrows only) *)
- "()"
- | [ lv ] -> lv
- | lvs -> "(" ^ String.concat " " lvs ^ ")"
- in
+ | Call (lv, call) ->
+ let lv = typed_lvalue_to_string val_fmt lv in
let ty_fmt = ast_to_type_formatter fmt in
let tys = List.map (ty_to_string ty_fmt) call.type_params in
let args = List.map (typed_rvalue_to_string fmt) call.args in
@@ -387,25 +382,11 @@ and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string)
if all_args = [] then fun_id
else fun_id ^ " " ^ String.concat " " all_args
in
- indent ^ "let " ^ lvs ^ " = " ^ call ^ " in\n" ^ e
- | Assign (lv, _, rv, _) ->
+ indent ^ "let " ^ lv ^ " = " ^ call ^ " in\n" ^ e
+ | Assign (lv, rv, _) ->
let lv = typed_lvalue_to_string val_fmt lv in
let rv = typed_rvalue_to_string fmt rv in
indent ^ "let " ^ lv ^ " = " ^ rv ^ " in\n" ^ e
- | Deconstruct (lvs, opt_adt_id, rv, _) ->
- let rv = typed_rvalue_to_string fmt rv in
- let lvs =
- List.map (fun (lv, _) -> var_or_dummy_to_string val_fmt lv) lvs
- in
- let lvs =
- match opt_adt_id with
- | None -> "(" ^ String.concat ", " lvs ^ ")"
- | Some (adt_id, variant_id) ->
- let cons = fmt.adt_variant_to_string adt_id variant_id in
- let lvs = if lvs = [] then "" else " " ^ String.concat " " lvs in
- cons ^ lvs
- in
- indent ^ "let " ^ lvs ^ " = " ^ rv ^ " in\n" ^ e
and switch_to_string (fmt : ast_formatter) (indent : string)
(indent_incr : string) (scrutinee : typed_rvalue) (body : switch_body) :
@@ -435,25 +416,8 @@ and switch_to_string (fmt : ast_formatter) (indent : string)
| Match branches ->
let val_fmt = ast_to_value_formatter fmt in
let branch_to_string (b : match_branch) : string =
- let adt_id =
- match scrutinee.ty with
- | Adt (type_id, _) -> (
- match type_id with
- | T.AdtId id -> id
- | T.Tuple | T.Assumed T.Box ->
- (* We can't match over a tuple or a box value *)
- failwith "Unreachable")
- | _ -> failwith "Unreachable"
- in
- let cons = fmt.adt_variant_to_string adt_id b.variant_id in
- let pats =
- if b.vars = [] then ""
- else
- " "
- ^ String.concat " "
- (List.map (var_or_dummy_to_string val_fmt) b.vars)
- in
- indent ^ "| " ^ cons ^ pats ^ " ->\n"
+ let pat = typed_lvalue_to_string val_fmt b.pat in
+ indent ^ "| " ^ pat ^ " ->\n"
^ expression_to_string fmt indent1 indent_incr b.branch
in
let branches = List.map branch_to_string branches in
diff --git a/src/Pure.ml b/src/Pure.ml
index ee4e74bb..61d2d130 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -95,22 +95,10 @@ type var = {
itself.
*)
-type var_or_dummy = Var of var | Dummy (** Ignored value: `_`. *)
-
-(** A left value (which appears on the left of assignments *)
-type lvalue =
- | LvVar of var_or_dummy
- | LvTuple of typed_lvalue list
- (** Rk.: for now we don't support general ADTs *)
-
-and typed_lvalue = { value : lvalue; ty : ty }
-
type projection_elem = { pkind : E.field_proj_kind; field_id : FieldId.id }
type projection = projection_elem list
-type place = { var : VarId.id; projection : projection }
-
type mplace = { name : string option; projection : projection }
(** "Meta" place.
@@ -119,6 +107,112 @@ type mplace = { name : string option; projection : projection }
we introduce.
*)
+type place = { var : VarId.id; projection : projection }
+
+(** Ancestor for [iter_var_or_dummy] iter visitor *)
+class ['self] iter_value_base =
+ object (_self : 'self)
+ inherit [_] VisitorsRuntime.iter
+
+ method visit_var : 'env -> var -> unit = fun _ _ -> ()
+
+ method visit_mplace : 'env -> mplace -> unit = fun _ _ -> ()
+
+ method visit_ty : 'env -> ty -> unit = fun _ _ -> ()
+ end
+
+(** Ancestor for [map_var_or_dummy] visitor *)
+class ['self] map_value_base =
+ object (_self : 'self)
+ inherit [_] VisitorsRuntime.map
+
+ method visit_var : 'env -> var -> var = fun _ x -> x
+
+ method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x
+
+ method visit_ty : 'env -> ty -> ty = fun _ x -> x
+ end
+
+(** Ancestor for [reduce_var_or_dummy] visitor *)
+class virtual ['self] reduce_value_base =
+ object (self : 'self)
+ inherit [_] VisitorsRuntime.reduce
+
+ method visit_var : 'env -> var -> 'a = fun _ _ -> self#zero
+
+ method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero
+
+ method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero
+ end
+
+type var_or_dummy =
+ | Var of var * mplace option
+ | Dummy (** Ignored value: `_`. *)
+[@@deriving
+ visitors
+ {
+ name = "iter_var_or_dummy";
+ variety = "iter";
+ ancestors = [ "iter_value_base" ];
+ nude = true (* Don't inherit [VisitorsRuntime.iter] *);
+ concrete = true;
+ polymorphic = false;
+ },
+ visitors
+ {
+ name = "map_var_or_dummy";
+ variety = "map";
+ ancestors = [ "map_value_base" ];
+ nude = true (* Don't inherit [VisitorsRuntime.map] *);
+ concrete = true;
+ polymorphic = false;
+ },
+ visitors
+ {
+ name = "reduce_var_or_dummy";
+ variety = "reduce";
+ ancestors = [ "reduce_value_base" ];
+ nude = true (* Don't inherit [VisitorsRuntime.reduce] *);
+ polymorphic = false;
+ }]
+
+(** A left value (which appears on the left of assignments *)
+type lvalue = LvVar of var_or_dummy | LvAdt of adt_lvalue
+
+and adt_lvalue = {
+ variant_id : (VariantId.id option[@opaque]);
+ field_values : typed_lvalue list;
+}
+
+and typed_lvalue = { value : lvalue; ty : ty }
+[@@deriving
+ visitors
+ {
+ name = "iter_typed_lvalue";
+ variety = "iter";
+ ancestors = [ "iter_var_or_dummy" ];
+ nude = true (* Don't inherit [VisitorsRuntime.iter] *);
+ concrete = true;
+ polymorphic = false;
+ },
+ visitors
+ {
+ name = "map_typed_lvalue";
+ variety = "map";
+ ancestors = [ "map_var_or_dummy" ];
+ nude = true (* Don't inherit [VisitorsRuntime.iter] *);
+ concrete = true;
+ polymorphic = false;
+ },
+ visitors
+ {
+ name = "reduce_typed_lvalue";
+ variety = "reduce";
+ ancestors = [ "reduce_var_or_dummy" ];
+ nude = true (* Don't inherit [VisitorsRuntime.iter] *);
+ polymorphic = false;
+ }]
+
type rvalue =
| RvConcrete of constant_value
| RvPlace of place
@@ -159,17 +253,15 @@ type call = {
args_mplaces : mplace option list; (** Meta data *)
}
+(* TODO: we might want to merge Call and Assign *)
type let_bindings =
- | Call of (typed_lvalue * mplace option) list * call
+ | Call of typed_lvalue * call
(** The called function and the tuple of returned values. *)
- | Assign of typed_lvalue * mplace option * typed_rvalue * mplace option
- (** Variable assignment: the introduced pattern and the place we read *)
- | Deconstruct of
- (var_or_dummy * mplace option) list
- * (TypeDefId.id * VariantId.id) option
- * typed_rvalue
- * mplace option
- (** This is used in two cases.
+ | Assign of typed_lvalue * typed_rvalue * mplace option
+ (** Variable assignment: the introduced pattern and the place we read.
+
+ We are quite general for the left-value on purpose; this is used
+ in several situations:
1. When deconstructing a tuple:
```
@@ -186,18 +278,13 @@ type let_bindings =
...
```
- Later, depending on the language we extract to, we can eventually
- update it to something like this (for F*, for instance):
+ Note that later, depending on the language we extract to, we can
+ eventually update it to something like this (for F*, for instance):
```
let x = Cons?.v ls in
let tl = Cons?.tl ls in
...
```
-
- Note that we prefer not handling this case through a match.
-
- TODO: why don't we merge this with Assign? It would make things a lot
- simpler (before: introduce general ADTs in lvalue).
*)
(** Meta-information stored in the AST *)
@@ -210,6 +297,8 @@ class ['self] iter_expression_base =
method visit_typed_rvalue : 'env -> typed_rvalue -> unit = fun _ _ -> ()
+ method visit_typed_lvalue : 'env -> typed_lvalue -> unit = fun _ _ -> ()
+
method visit_let_bindings : 'env -> let_bindings -> unit = fun _ _ -> ()
method visit_mplace : 'env -> mplace -> unit = fun _ _ -> ()
@@ -233,6 +322,9 @@ class ['self] map_expression_base =
method visit_typed_rvalue : 'env -> typed_rvalue -> typed_rvalue =
fun _ x -> x
+ method visit_typed_lvalue : 'env -> typed_lvalue -> typed_lvalue =
+ fun _ x -> x
+
method visit_let_bindings : 'env -> let_bindings -> let_bindings =
fun _ x -> x
@@ -274,11 +366,7 @@ and switch_body =
| SwitchInt of T.integer_type * (scalar_value * expression) list * expression
| Match of match_branch list
-and match_branch = {
- variant_id : VariantId.id;
- vars : var_or_dummy list;
- branch : expression;
-}
+and match_branch = { pat : typed_lvalue; branch : expression }
[@@deriving
visitors
{
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index fa2a6e16..80c35124 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -69,16 +69,34 @@ type pn_ctx = string VarId.Map.t
*)
let compute_pretty_names (def : fun_def) : fun_def =
(* Small helpers *)
+ (*
+ * When we do branchings, we need to merge (the constraints saved in) the
+ * contexts returned by the different branches.
+ *
+ * Note that by doing so, some mappings from var id to name
+ * in one context may be overriden by the ones in the other context.
+ *
+ * This should be ok because:
+ * - generally, the overriden variables should have been introduced *inside*
+ * the branches, in which case we don't care
+ * - or they were introduced before, in which case the naming should generally
+ * be consistent? In the worse case, it isn't, but it leads only to less
+ * readable code, not to unsoundness. This case should be pretty rare,
+ * also.
+ *)
+ let merge_ctxs (ctx0 : pn_ctx) (ctx1 : pn_ctx) : pn_ctx =
+ VarId.Map.fold (fun id name ctx -> VarId.Map.add id name ctx) ctx0 ctx1
+ in
+ let merge_ctxs_ls (ctxs : pn_ctx list) : pn_ctx =
+ List.fold_left (fun ctx0 ctx1 -> merge_ctxs ctx0 ctx1) VarId.Map.empty ctxs
+ in
+
let add_var (ctx : pn_ctx) (v : var) : pn_ctx =
assert (not (VarId.Map.mem v.id ctx));
match v.basename with
| None -> ctx
| Some name -> VarId.Map.add v.id name ctx
in
- let add_var_or_dummy (ctx : pn_ctx) (v : var_or_dummy) : pn_ctx =
- match v with Dummy -> ctx | Var v -> add_var ctx v
- in
- let add_var_or_dummy_list ctx ls = List.fold_left add_var_or_dummy ctx ls in
let update_var (ctx : pn_ctx) (v : var) : var =
match v.basename with
| Some _ -> v
@@ -87,17 +105,15 @@ let compute_pretty_names (def : fun_def) : fun_def =
| None -> v
| Some basename -> { v with basename = Some basename })
in
- let update_var_or_dummy (ctx : pn_ctx) (v : var_or_dummy) : var_or_dummy =
- match v with Dummy -> Dummy | Var v -> Var (update_var ctx v)
- in
- let update_var_or_dummy_list ctx = List.map (update_var_or_dummy ctx) in
- let update_typed_lvalue ctx (lv : typed_lvalue) =
- let value =
- match lv.value with
- | LvVar v -> LvVar (update_var_or_dummy ctx v)
- | v -> v
+ let update_typed_lvalue ctx (lv : typed_lvalue) : typed_lvalue =
+ let obj =
+ object
+ inherit [_] map_typed_lvalue
+
+ method! visit_var _ v = update_var ctx v
+ end
in
- { lv with value }
+ obj#visit_typed_lvalue () lv
in
let add_constraint (mp : mplace) (var_id : VarId.id) (ctx : pn_ctx) : pn_ctx =
@@ -122,48 +138,20 @@ let compute_pretty_names (def : fun_def) : fun_def =
(fun ctx (mp, rv) -> add_opt_right_constraint mp rv ctx)
ctx rvs
in
- let add_left_constraint_var_or_dummy (mp : mplace option) (v : var_or_dummy)
- (ctx : pn_ctx) : pn_ctx =
- let ctx = add_var_or_dummy ctx v in
- match (v, mp) with Var v, Some mp -> add_constraint mp v.id ctx | _ -> ctx
- in
- let add_left_constraint_typed_value (mp : mplace option) (lv : typed_lvalue)
- (ctx : pn_ctx) : pn_ctx =
- match lv.value with
- | LvTuple _ | LvVar Dummy -> ctx
- | LvVar v -> add_left_constraint_var_or_dummy mp v ctx
- in
- let add_left_constraint_var_or_dummy_list ctx lvs =
- List.fold_left
- (fun ctx (v, mp) -> add_left_constraint_var_or_dummy mp v ctx)
- ctx lvs
- in
- let add_left_constraint_typed_value_list ctx lvs =
- List.fold_left
- (fun ctx (v, mp) -> add_left_constraint_typed_value mp v ctx)
- ctx lvs
- in
+ let add_left_constraint (lv : typed_lvalue) (ctx : pn_ctx) : pn_ctx =
+ let obj =
+ object (self)
+ inherit [_] reduce_typed_lvalue
- (*
- * When we do branchings, we need to merge (the constraints saved in) the
- * contexts returned by the different branches.
- *
- * Note that by doing so, some mappings from var id to name
- * in one context may be overriden by the ones in the other context.
- *
- * This should be ok because:
- * - generally, the overriden variables should have been introduced *inside*
- * the branches, in which case we don't care
- * - or they were introduced before, in which case the naming should generally
- * be consistent? In the worse case, it isn't, but it leads only to less
- * readable code, not to unsoundness. This case should be pretty rare,
- * also.
- *)
- let merge_ctxs (ctx0 : pn_ctx) (ctx1 : pn_ctx) : pn_ctx =
- VarId.Map.fold (fun id name ctx -> VarId.Map.add id name ctx) ctx0 ctx1
- in
- let merge_ctxs_ls (ctxs : pn_ctx list) : pn_ctx =
- List.fold_left (fun ctx0 ctx1 -> merge_ctxs ctx0 ctx1) VarId.Map.empty ctxs
+ method zero _ = VarId.Map.empty
+
+ method plus ctx0 ctx1 _ = merge_ctxs (ctx0 ()) (ctx1 ())
+
+ method! visit_var _ v () = add_var (self#zero ()) v
+ end
+ in
+ let ctx1 = obj#visit_typed_lvalue () lv () in
+ merge_ctxs ctx ctx1
in
(* *)
@@ -178,31 +166,21 @@ let compute_pretty_names (def : fun_def) : fun_def =
and update_let (lb : let_bindings) (e : expression) (ctx : pn_ctx) :
pn_ctx * expression =
match lb with
- | Call (lvs, call) ->
+ | Call (lv, call) ->
let ctx =
add_opt_right_constraint_list ctx
(List.combine call.args_mplaces call.args)
in
- let ctx = add_left_constraint_typed_value_list ctx lvs in
- let ctx, e = update_expression e ctx in
- let lvs =
- List.map (fun (v, mp) -> (update_typed_lvalue ctx v, mp)) lvs
- in
- (ctx, Let (Call (lvs, call), e))
- | Assign (lv, lmp, rv, rmp) ->
- let ctx = add_left_constraint_typed_value lmp lv ctx in
- let ctx = add_opt_right_constraint rmp rv ctx in
+ let ctx = add_left_constraint lv ctx in
let ctx, e = update_expression e ctx in
let lv = update_typed_lvalue ctx lv in
- (ctx, Let (Assign (lv, lmp, rv, rmp), e))
- | Deconstruct (lvs, opt_variant_id, rv, rmp) ->
- let ctx = add_left_constraint_var_or_dummy_list ctx lvs in
+ (ctx, Let (Call (lv, call), e))
+ | Assign (lv, rv, rmp) ->
+ let ctx = add_left_constraint lv ctx in
let ctx = add_opt_right_constraint rmp rv ctx in
let ctx, e = update_expression e ctx in
- let lvs =
- List.map (fun (v, mp) -> (update_var_or_dummy ctx v, mp)) lvs
- in
- (ctx, Let (Deconstruct (lvs, opt_variant_id, rv, rmp), e))
+ let lv = update_typed_lvalue ctx lv in
+ (ctx, Let (Assign (lv, rv, rmp), e))
(* *)
and update_switch_body (scrut : typed_rvalue) (mp : mplace option)
(body : switch_body) (ctx : pn_ctx) : pn_ctx * expression =
@@ -232,10 +210,10 @@ let compute_pretty_names (def : fun_def) : fun_def =
let ctx_branches_ls =
List.map
(fun br ->
- let ctx = add_var_or_dummy_list ctx br.vars in
+ let ctx = add_left_constraint br.pat ctx in
let ctx, branch = update_expression br.branch ctx in
- let vars = update_var_or_dummy_list ctx br.vars in
- (ctx, { br with branch; vars }))
+ let pat = update_typed_lvalue ctx br.pat in
+ (ctx, { pat; branch }))
branches
in
let ctxs, branches = List.split ctx_branches_ls in
@@ -268,7 +246,7 @@ let remove_meta (def : fun_def) : fun_def =
object
inherit [_] map_expression as super
- method! visit_Meta env meta e = super#visit_expression env e
+ method! visit_Meta env _ e = super#visit_expression env e
end
in
let body = obj#visit_expression () def.body in
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index d46d8386..7fd72926 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -24,11 +24,22 @@ let mk_typed_rvalue_from_var (v : var) : typed_rvalue =
{ value; ty }
(* TODO: move *)
-let mk_typed_lvalue_from_var (v : var) : typed_lvalue =
- let value = LvVar (Var v) in
+let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue =
+ let value = LvVar (Var (v, mp)) in
let ty = v.ty in
{ value; ty }
+let mk_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue =
+ let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in
+ let ty = Adt (T.Tuple, tys) in
+ let value = LvAdt { variant_id = None; field_values = vl } in
+ { value; ty }
+
+let mk_adt_lvalue (adt_ty : ty) (variant_id : VariantId.id)
+ (vl : typed_lvalue list) : typed_lvalue =
+ let value = LvAdt { variant_id = Some variant_id; field_values = vl } in
+ { value; ty = adt_ty }
+
let ty_as_integer (t : ty) : T.integer_type =
match t with Integer int_ty -> int_ty | _ -> failwith "Unreachable"
@@ -786,7 +797,7 @@ let rec typed_avalue_to_given_back (av : V.typed_avalue) (ctx : bs_ctx) :
assert (variant_id = None);
if field_values = [] then (ctx, None)
else
- let value = LvTuple field_values in
+ let value = LvAdt { variant_id = None; field_values } in
let ty = ctx_translate_fwd_ty ctx av.ty in
let lv : typed_lvalue = { value; ty } in
(ctx, Some lv))
@@ -822,7 +833,7 @@ and aborrow_content_to_given_back (bc : V.aborrow_content) (ctx : bs_ctx) :
| AEndedMutBorrow (mv, _) ->
(* Return the meta-value *)
let ctx, var = fresh_var_for_symbolic_value mv ctx in
- (ctx, Some (mk_typed_lvalue_from_var var))
+ (ctx, Some (mk_typed_lvalue_from_var var None))
| AEndedIgnoredMutBorrow _ ->
(* This happens with nested borrows: we need to dive in *)
raise Unimplemented
@@ -844,7 +855,7 @@ and aproj_to_given_back (aproj : V.aproj) (ctx : bs_ctx) :
| AEndedProjBorrows mv ->
(* Return the meta-value *)
let ctx, var = fresh_var_for_symbolic_value mv ctx in
- (ctx, Some (mk_typed_lvalue_from_var var))
+ (ctx, Some (mk_typed_lvalue_from_var var None))
| AIgnoredProjBorrows | AProjLoans (_, _) | AProjBorrows (_, _) ->
failwith "Unreachable"
@@ -947,7 +958,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(* Translate the next expression *)
let e = translate_expression e ctx in
(* Put together *)
- Let (Call ([ (mk_typed_lvalue_from_var dest, dest_mplace) ], call), e)
+ Let (Call (mk_typed_lvalue_from_var dest dest_mplace, call), e)
and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
expression =
@@ -1002,7 +1013,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Generate the assignemnts *)
List.fold_right
(fun (var, value) e ->
- Let (Assign (mk_typed_lvalue_from_var var, None, value, None), e))
+ Let (Assign (mk_typed_lvalue_from_var var None, value, None), e))
variables_values e
| V.FunCall ->
let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
@@ -1026,6 +1037,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Retrieve the values given back by this function: those are the output
* values *)
let ctx, outputs = abs_to_given_back abs ctx in
+ let output = mk_tuple_lvalue outputs in
(* Sanity check: the inputs and outputs have the proper number and the proper type *)
let fun_id =
match call.call_id with
@@ -1057,8 +1069,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Put everything together *)
let args_mplaces = List.map (fun _ -> None) inputs in
let call = { func; type_params; args = inputs; args_mplaces } in
- let outputs = List.map (fun x -> (x, None)) outputs in
- Let (Call (outputs, call), e)
+ Let (Call (output, call), e)
| V.SynthRet ->
(* If we end the abstraction which consumed the return value of the function
* we are synthesizing, we get back the borrows which were inside. Those borrows
@@ -1111,9 +1122,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Generate the assignments *)
List.fold_right
(fun (given_back, input_var) e ->
- Let
- ( Assign (given_back, None, mk_typed_rvalue_from_var input_var, None),
- e ))
+ Let (Assign (given_back, mk_typed_rvalue_from_var input_var, None), e))
given_back_inputs e
and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
@@ -1137,7 +1146,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
let e = translate_expression e ctx in
Let
( Assign
- (mk_typed_lvalue_from_var var, None, scrutinee, scrutinee_mplace),
+ (mk_typed_lvalue_from_var var None, scrutinee, scrutinee_mplace),
e )
| SeAdt _ ->
(* Should be in the [ExpandAdt] case *)
@@ -1158,14 +1167,11 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
if is_enum then
(* This is an enumeration: introduce an [ExpandEnum] let-binding *)
let variant_id = Option.get variant_id in
- let vars = List.map (fun x -> (Var x, None)) vars in
- Let
- ( Deconstruct
- ( vars,
- Some (adt_id, variant_id),
- scrutinee,
- scrutinee_mplace ),
- branch )
+ let lvars =
+ List.map (fun v -> mk_typed_lvalue_from_var v None) vars
+ in
+ let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in
+ Let (Assign (lv, scrutinee, scrutinee_mplace), branch)
else
(* This is not an enumeration: introduce let-bindings for every
* field.
@@ -1187,12 +1193,16 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
let field_proj = gen_field_proj fid var in
Let
( Assign
- (mk_typed_lvalue_from_var var, None, field_proj, None),
+ (mk_typed_lvalue_from_var var None, field_proj, None),
e ))
id_var_pairs branch
| T.Tuple ->
- let vars = List.map (fun x -> (Var x, None)) vars in
- Let (Deconstruct (vars, None, scrutinee, scrutinee_mplace), branch)
+ let vars =
+ List.map (fun x -> mk_typed_lvalue_from_var x None) vars
+ in
+ Let
+ ( Assign (mk_tuple_lvalue vars, scrutinee, scrutinee_mplace),
+ branch )
| T.Assumed T.Box ->
(* There should be exactly one variable *)
let var =
@@ -1202,8 +1212,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
* identity when extracted (`box a == a`) *)
Let
( Assign
- ( mk_typed_lvalue_from_var var,
- None,
+ ( mk_typed_lvalue_from_var var None,
scrutinee,
scrutinee_mplace ),
branch ))
@@ -1214,9 +1223,13 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* There *must* be a variant id - otherwise there can't be several branches *)
let variant_id = Option.get variant_id in
let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
- let vars = List.map (fun x -> Var x) vars in
+ let vars =
+ List.map (fun x -> mk_typed_lvalue_from_var x None) vars
+ in
+ let pat_ty = scrutinee.ty in
+ let pat = mk_adt_lvalue pat_ty variant_id vars in
let branch = translate_expression branch ctx in
- { variant_id; vars; branch }
+ { pat; branch }
in
let branches =
List.map (fun (vid, svl, e) -> translate_branch vid svl e) branches