diff options
author | Son Ho | 2022-02-04 09:56:44 +0100 |
---|---|---|
committer | Son Ho | 2022-02-04 09:56:44 +0100 |
commit | 527d828067bc4780641c979ddc880f98322e4c31 (patch) | |
tree | 1a4856ac9bb88b80c05245a867038c6691abecca /src | |
parent | 540c13ac94e00fae062cd328903711ea9693ddfc (diff) |
Update SymbolicToPure so that we don't construct tuples with exactly one
field
Diffstat (limited to '')
-rw-r--r-- | src/PureMicroPasses.ml | 2 | ||||
-rw-r--r-- | src/PureUtils.ml | 34 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 45 | ||||
-rw-r--r-- | src/Translate.ml | 2 |
4 files changed, 51 insertions, 32 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index e2894fe9..6b42e328 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -650,7 +650,7 @@ let to_monadic (def : fun_def) : fun_def = mk_result_ty out_ty | Some _, outputs -> (* Backward function: we have to group them *) - mk_result_ty (mk_tuple_ty outputs) + mk_result_ty (mk_simpl_tuple_ty outputs) | _ -> failwith "Unreachable" in let outputs = [ output_ty ] in diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 9f3bd5ef..c19d7914 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -42,7 +42,12 @@ let binop_can_fail (binop : E.binop) : bool = let mk_place_from_var (v : var) : place = { var = v.id; projection = [] } -let mk_tuple_ty (tys : ty list) : ty = Adt (Tuple, tys) +(** Make a "simplified" tuple type from a list of types: + - if there is exactly one type, just return it + - if there is > one type: wrap them in a tuple + *) +let mk_simpl_tuple_ty (tys : ty list) : ty = + match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys) let unit_ty : ty = Adt (Tuple, []) @@ -61,11 +66,28 @@ let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue = let ty = v.ty in { value; ty } -let mk_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue = - let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in - let ty = Adt (Tuple, tys) in - let value = LvAdt { variant_id = None; field_values = vl } in - { value; ty } +(** Make a "simplified" tuple value from a list of values: + - if there is exactly one value, just return it + - if there is > one value: wrap them in a tuple + *) +let mk_simpl_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue = + match vl with + | [ v ] -> v + | _ -> + let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in + let ty = Adt (Tuple, tys) in + let value = LvAdt { variant_id = None; field_values = vl } in + { value; ty } + +(** Similar to [mk_simpl_tuple_lvalue] *) +let mk_simpl_tuple_rvalue (vl : typed_rvalue list) : typed_rvalue = + match vl with + | [ v ] -> v + | _ -> + let tys = List.map (fun (v : typed_rvalue) -> v.ty) vl in + let ty = Adt (Tuple, tys) in + let value = RvAdt { variant_id = None; field_values = vl } in + { value; ty } let mk_adt_lvalue (adt_ty : ty) (variant_id : VariantId.id) (vl : typed_lvalue list) : typed_lvalue = diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 59f69d17..9a611899 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -183,7 +183,7 @@ let rec translate_sty (ty : T.sty) : ty = let tys = List.map translate tys in match type_id with | T.AdtId adt_id -> Adt (AdtId adt_id, tys) - | T.Tuple -> Adt (Tuple, tys) + | T.Tuple -> mk_simpl_tuple_ty tys | T.Assumed T.Box -> ( match tys with | [ ty ] -> ty @@ -253,7 +253,10 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = (* No general parametricity for now *) assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys)); Adt (AdtId adt_id, t_tys) - | Tuple -> Adt (Tuple, t_tys) + | Tuple -> + (* Note that if there is exactly one type, [mk_simpl_tuple_ty] is the + identity *) + mk_simpl_tuple_ty t_tys | T.Assumed T.Box -> ( (* No general parametricity for now *) assert (not (List.exists (TypesUtils.ty_has_borrows types_infos) tys)); @@ -317,7 +320,12 @@ let rec translate_back_ty (types_infos : TA.type_infos) | T.Tuple -> ( (* Tuples can contain borrows (which we eliminated) *) let tys_t = List.filter_map translate tys in - match tys_t with [] -> None | _ -> Some (Adt (Tuple, tys_t)))) + match tys_t with + | [] -> None + | _ -> + (* Note that if there is exactly one type, [mk_simpl_tuple_ty] + * is the identity *) + Some (mk_simpl_tuple_ty tys_t))) | TypeVar vid -> wrap (TypeVar vid) | Bool -> wrap Bool | Char -> wrap Char @@ -605,16 +613,11 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (av : V.typed_avalue) : None | T.Tuple -> (* Return *) - let variant_id = adt_v.variant_id in if field_values = [] then None else - let value = RvAdt { variant_id; field_values } in - let tys = - List.map (fun (fv : typed_rvalue) -> fv.ty) field_values - in - (* TODO: don't use a tuple wrapper if exactly one value *) - let ty = Adt (Tuple, tys) in - let rv = { value; ty } in + (* Note that if there is exactly one field value, + * [mk_simpl_tuple_rvalue] is the identity *) + let rv = mk_simpl_tuple_rvalue field_values in Some rv) | ABottom -> failwith "Unreachable" | ALoan lc -> aloan_content_to_consumed ctx lc @@ -747,13 +750,9 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue) assert (variant_id = None); if field_values = [] then (ctx, None) else - let value = LvAdt { variant_id = None; field_values } in - let tys = - List.map (fun (fv : typed_lvalue) -> fv.ty) field_values - in - (* TODO: don't use a tuple wrapper if exactly one value *) - let ty = Adt (Tuple, tys) in - let lv : typed_lvalue = { value; ty } in + (* Note that if there is exactly one field value, [mk_simpl_tuple_lvalue] + * is the identity *) + let lv = mk_simpl_tuple_lvalue field_values in (ctx, Some lv)) | ABottom -> failwith "Unreachable" | ALoan lc -> aloan_content_to_given_back mp lc ctx @@ -897,10 +896,7 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression T.RegionGroupId.Map.find bid ctx.backward_outputs in let field_values = List.map mk_typed_rvalue_from_var backward_outputs in - let ret_value = RvAdt { variant_id = None; field_values } in - let ret_tys = List.map (fun (v : typed_rvalue) -> v.ty) field_values in - let ret_ty = Adt (Tuple, ret_tys) in - let ret_value : typed_rvalue = { value = ret_value; ty = ret_ty } in + let ret_value = mk_simpl_tuple_rvalue field_values in let ret_value = mk_result_return_rvalue ret_value in let e = Value (ret_value, None) in let ty = ret_value.ty in @@ -1043,7 +1039,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : List.append (List.map translate_opt_mplace call.args_places) [ None ] in let ctx, outputs = abs_to_given_back output_mpl abs ctx in - let output = mk_tuple_lvalue outputs in + let output = mk_simpl_tuple_lvalue outputs in (* Sanity check: the inputs and outputs have the proper number and the proper type *) let fun_id = match call.call_id with @@ -1252,7 +1248,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.map (fun x -> mk_typed_lvalue_from_var x None) vars in let monadic = false in - mk_let monadic (mk_tuple_lvalue vars) + mk_let monadic + (mk_simpl_tuple_lvalue vars) (mk_value_expression scrutinee scrutinee_mplace) branch | T.Assumed T.Box -> diff --git a/src/Translate.ml b/src/Translate.ml index a5189606..252be079 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -174,7 +174,7 @@ let translate_function_to_pure (config : C.partial_config) let backward_output_tys = List.map (fun (v : Pure.var) -> v.ty) backward_outputs in - let backward_ret_ty = mk_tuple_ty backward_output_tys in + let backward_ret_ty = mk_simpl_tuple_ty backward_output_tys in let backward_inputs = T.RegionGroupId.Map.singleton back_id backward_inputs in |