summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/PureMicroPasses.ml2
-rw-r--r--src/PureUtils.ml34
-rw-r--r--src/SymbolicToPure.ml45
-rw-r--r--src/Translate.ml2
4 files changed, 51 insertions, 32 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index e2894fe9..6b42e328 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -650,7 +650,7 @@ let to_monadic (def : fun_def) : fun_def =
mk_result_ty out_ty
| Some _, outputs ->
(* Backward function: we have to group them *)
- mk_result_ty (mk_tuple_ty outputs)
+ mk_result_ty (mk_simpl_tuple_ty outputs)
| _ -> failwith "Unreachable"
in
let outputs = [ output_ty ] in
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index 9f3bd5ef..c19d7914 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -42,7 +42,12 @@ let binop_can_fail (binop : E.binop) : bool =
let mk_place_from_var (v : var) : place = { var = v.id; projection = [] }
-let mk_tuple_ty (tys : ty list) : ty = Adt (Tuple, tys)
+(** Make a "simplified" tuple type from a list of types:
+ - if there is exactly one type, just return it
+ - if there is > one type: wrap them in a tuple
+ *)
+let mk_simpl_tuple_ty (tys : ty list) : ty =
+ match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys)
let unit_ty : ty = Adt (Tuple, [])
@@ -61,11 +66,28 @@ let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue =
let ty = v.ty in
{ value; ty }
-let mk_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue =
- let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in
- let ty = Adt (Tuple, tys) in
- let value = LvAdt { variant_id = None; field_values = vl } in
- { value; ty }
+(** Make a "simplified" tuple value from a list of values:
+ - if there is exactly one value, just return it
+ - if there is > one value: wrap them in a tuple
+ *)
+let mk_simpl_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue =
+ match vl with
+ | [ v ] -> v
+ | _ ->
+ let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in
+ let ty = Adt (Tuple, tys) in
+ let value = LvAdt { variant_id = None; field_values = vl } in
+ { value; ty }
+
+(** Similar to [mk_simpl_tuple_lvalue] *)
+let mk_simpl_tuple_rvalue (vl : typed_rvalue list) : typed_rvalue =
+ match vl with
+ | [ v ] -> v
+ | _ ->
+ let tys = List.map (fun (v : typed_rvalue) -> v.ty) vl in
+ let ty = Adt (Tuple, tys) in
+ let value = RvAdt { variant_id = None; field_values = vl } in
+ { value; ty }
let mk_adt_lvalue (adt_ty : ty) (variant_id : VariantId.id)
(vl : typed_lvalue list) : typed_lvalue =
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 59f69d17..9a611899 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -183,7 +183,7 @@ let rec translate_sty (ty : T.sty) : ty =
let tys = List.map translate tys in
match type_id with
| T.AdtId adt_id -> Adt (AdtId adt_id, tys)
- | T.Tuple -> Adt (Tuple, tys)
+ | T.Tuple -> mk_simpl_tuple_ty tys
| T.Assumed T.Box -> (
match tys with
| [ ty ] -> ty
@@ -253,7 +253,10 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
(* No general parametricity for now *)
assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys));
Adt (AdtId adt_id, t_tys)
- | Tuple -> Adt (Tuple, t_tys)
+ | Tuple ->
+ (* Note that if there is exactly one type, [mk_simpl_tuple_ty] is the
+ identity *)
+ mk_simpl_tuple_ty t_tys
| T.Assumed T.Box -> (
(* No general parametricity for now *)
assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys));
@@ -317,7 +320,12 @@ let rec translate_back_ty (types_infos : TA.type_infos)
| T.Tuple -> (
(* Tuples can contain borrows (which we eliminated) *)
let tys_t = List.filter_map translate tys in
- match tys_t with [] -> None | _ -> Some (Adt (Tuple, tys_t))))
+ match tys_t with
+ | [] -> None
+ | _ ->
+ (* Note that if there is exactly one type, [mk_simpl_tuple_ty]
+ * is the identity *)
+ Some (mk_simpl_tuple_ty tys_t)))
| TypeVar vid -> wrap (TypeVar vid)
| Bool -> wrap Bool
| Char -> wrap Char
@@ -605,16 +613,11 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (av : V.typed_avalue) :
None
| T.Tuple ->
(* Return *)
- let variant_id = adt_v.variant_id in
if field_values = [] then None
else
- let value = RvAdt { variant_id; field_values } in
- let tys =
- List.map (fun (fv : typed_rvalue) -> fv.ty) field_values
- in
- (* TODO: don't use a tuple wrapper if exactly one value *)
- let ty = Adt (Tuple, tys) in
- let rv = { value; ty } in
+ (* Note that if there is exactly one field value,
+ * [mk_simpl_tuple_rvalue] is the identity *)
+ let rv = mk_simpl_tuple_rvalue field_values in
Some rv)
| ABottom -> failwith "Unreachable"
| ALoan lc -> aloan_content_to_consumed ctx lc
@@ -747,13 +750,9 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue)
assert (variant_id = None);
if field_values = [] then (ctx, None)
else
- let value = LvAdt { variant_id = None; field_values } in
- let tys =
- List.map (fun (fv : typed_lvalue) -> fv.ty) field_values
- in
- (* TODO: don't use a tuple wrapper if exactly one value *)
- let ty = Adt (Tuple, tys) in
- let lv : typed_lvalue = { value; ty } in
+ (* Note that if there is exactly one field value, [mk_simpl_tuple_lvalue]
+ * is the identity *)
+ let lv = mk_simpl_tuple_lvalue field_values in
(ctx, Some lv))
| ABottom -> failwith "Unreachable"
| ALoan lc -> aloan_content_to_given_back mp lc ctx
@@ -897,10 +896,7 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
T.RegionGroupId.Map.find bid ctx.backward_outputs
in
let field_values = List.map mk_typed_rvalue_from_var backward_outputs in
- let ret_value = RvAdt { variant_id = None; field_values } in
- let ret_tys = List.map (fun (v : typed_rvalue) -> v.ty) field_values in
- let ret_ty = Adt (Tuple, ret_tys) in
- let ret_value : typed_rvalue = { value = ret_value; ty = ret_ty } in
+ let ret_value = mk_simpl_tuple_rvalue field_values in
let ret_value = mk_result_return_rvalue ret_value in
let e = Value (ret_value, None) in
let ty = ret_value.ty in
@@ -1043,7 +1039,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
List.append (List.map translate_opt_mplace call.args_places) [ None ]
in
let ctx, outputs = abs_to_given_back output_mpl abs ctx in
- let output = mk_tuple_lvalue outputs in
+ let output = mk_simpl_tuple_lvalue outputs in
(* Sanity check: the inputs and outputs have the proper number and the proper type *)
let fun_id =
match call.call_id with
@@ -1252,7 +1248,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.map (fun x -> mk_typed_lvalue_from_var x None) vars
in
let monadic = false in
- mk_let monadic (mk_tuple_lvalue vars)
+ mk_let monadic
+ (mk_simpl_tuple_lvalue vars)
(mk_value_expression scrutinee scrutinee_mplace)
branch
| T.Assumed T.Box ->
diff --git a/src/Translate.ml b/src/Translate.ml
index a5189606..252be079 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -174,7 +174,7 @@ let translate_function_to_pure (config : C.partial_config)
let backward_output_tys =
List.map (fun (v : Pure.var) -> v.ty) backward_outputs
in
- let backward_ret_ty = mk_tuple_ty backward_output_tys in
+ let backward_ret_ty = mk_simpl_tuple_ty backward_output_tys in
let backward_inputs =
T.RegionGroupId.Map.singleton back_id backward_inputs
in