summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ExtractToFStar.ml4
-rw-r--r--src/PureMicroPasses.ml19
-rw-r--r--src/PureUtils.ml24
-rw-r--r--src/SymbolicToPure.ml12
4 files changed, 31 insertions, 28 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 0a06ef73..068448e9 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -1475,8 +1475,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
let sg = def.signature in
if
sg.type_params = []
- && (sg.inputs = [ unit_ty ] || sg.inputs = [])
- && sg.outputs = [ mk_result_ty unit_ty ]
+ && (sg.inputs = [ mk_unit_ty ] || sg.inputs = [])
+ && sg.outputs = [ mk_result_ty mk_unit_ty ]
then (
(* Add a break before *)
F.pp_print_break fmt 0 0;
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index bd7d7766..f76dd2f4 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -923,8 +923,8 @@ let filter_if_backward_with_no_outputs (config : config) (def : fun_decl) :
let return_ty =
if config.use_state_monad then
mk_arrow mk_state_ty
- (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ]))
- else mk_result_ty (mk_simpl_tuple_ty [ unit_ty ])
+ (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; mk_unit_ty ]))
+ else mk_result_ty (mk_simpl_tuple_ty [ mk_unit_ty ])
in
if
config.filter_useless_functions && Option.is_some def.back_id
@@ -954,7 +954,7 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool =
* they should be lists of length 1. *)
if
config.filter_useless_functions
- && fwd.signature.outputs = [ mk_result_ty unit_ty ]
+ && fwd.signature.outputs = [ mk_result_ty mk_unit_ty ]
&& backs <> []
then false
else true
@@ -968,11 +968,12 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl =
inherit [_] map_expression as super
method! visit_PatVar _ v mp =
- if v.ty = unit_ty then PatDummy else PatVar (v, mp)
+ if v.ty = mk_unit_ty then PatDummy else PatVar (v, mp)
(** Replace in patterns *)
method! visit_texpression env e =
- if e.ty = unit_ty then unit_rvalue else super#visit_texpression env e
+ if e.ty = mk_unit_ty then mk_unit_rvalue
+ else super#visit_texpression env e
(** Replace in "regular" expressions - note that we could limit ourselves
to variables, but this is more powerful
*)
@@ -1026,7 +1027,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
| A.BoxDeref, Some _ ->
(* `Box::deref` backward is `()` (doesn't give back anything) *)
assert (args = []);
- unit_rvalue
+ mk_unit_rvalue
| A.BoxDerefMut, None ->
(* `Box::deref_mut` forward is the identity *)
let arg, args = Collections.List.pop args in
@@ -1043,7 +1044,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
mk_apps arg args
| A.BoxFree, _ ->
assert (args = []);
- unit_rvalue
+ mk_unit_rvalue
| ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen
| A.VecIndex | A.VecIndexMut ),
_ ) ->
@@ -1215,7 +1216,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
in
(* Create the match *)
let fail_pat = mk_result_fail_pattern re_no_monad_ty in
- let fail_value = mk_result_fail_rvalue e_no_monad_ty in
+ let fail_value = mk_result_fail_texpression e_no_monad_ty in
let fail_branch = { pat = fail_pat; branch = fail_value } in
(* The `Success` branch introduces a fresh state variable *)
let pat_state_var = fresh_state_var () in
@@ -1246,7 +1247,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
let re_ty = Option.get (opt_destruct_result re.ty) in
assert (lv.ty = re_ty);
let fail_pat = mk_result_fail_pattern lv.ty in
- let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_value = mk_result_fail_texpression e.ty in
let fail_branch = { pat = fail_pat; branch = fail_value } in
let success_pat = mk_result_return_pattern lv in
let success_branch = { pat = success_pat; branch = e } in
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index fe71b3b2..73794a7c 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -323,9 +323,15 @@ let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) :
| Match branches -> List.iter (fun (b : match_branch) -> f b.branch) branches
let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
- (* TODO: check the type of the scrutinee *)
- let ty = get_switch_body_ty sb in
+ (* Sanity check: the scrutinee has the proper type *)
+ (match sb with
+ | If (_, _) -> assert (scrut.ty = Bool)
+ | Match branches ->
+ List.iter
+ (fun (b : match_branch) -> assert (b.pat.ty = scrut.ty))
+ branches);
(* Sanity check: all the branches have the same type *)
+ let ty = get_switch_body_ty sb in
iter_switch_body_branches (fun e -> assert (e.ty = ty)) sb;
(* Put together *)
let e = Switch (scrut, sb) in
@@ -338,15 +344,13 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
let mk_simpl_tuple_ty (tys : ty list) : ty =
match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys)
-(** TODO: rename to "mk_..." *)
-let unit_ty : ty = Adt (Tuple, [])
+let mk_unit_ty : ty = Adt (Tuple, [])
-(** TODO: rename to "mk_unit_texpression" *)
-let unit_rvalue : texpression =
+let mk_unit_rvalue : texpression =
let id = AdtCons { adt_id = Tuple; variant_id = None } in
let qualif = { id; type_args = [] } in
let e = Qualif qualif in
- let ty = unit_ty in
+ let ty = mk_unit_ty in
{ e; ty }
let mk_texpression_from_var (v : var) : texpression =
@@ -416,8 +420,7 @@ let mk_state_ty : ty = Adt (Assumed State, [])
let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
-(* TODO: rename *)
-let mk_result_fail_rvalue (ty : ty) : texpression =
+let mk_result_fail_texpression (ty : ty) : texpression =
let type_args = [ ty ] in
let ty = Adt (Assumed Result, type_args) in
let id =
@@ -429,8 +432,7 @@ let mk_result_fail_rvalue (ty : ty) : texpression =
let cons = { e = cons_e; ty = cons_ty } in
cons
-(* TODO: rename *)
-let mk_result_return_rvalue (v : texpression) : texpression =
+let mk_result_return_texpression (v : texpression) : texpression =
let type_args = [ v.ty ] in
let ty = Adt (Assumed Result, type_args) in
let id =
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index c0e47ca7..156d5661 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -1019,14 +1019,14 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression =
if config.use_state_monad then
(* Create the `Fail` value *)
let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.ret_ty ] in
- let ret_v = mk_result_fail_rvalue ret_ty in
+ let ret_v = mk_result_fail_texpression ret_ty in
(* Add the lambda *)
let _, state_var =
fresh_var (Some ConstStrings.state_basename) mk_state_ty ctx
in
let state_pattern = mk_typed_pattern_from_var state_var None in
mk_abs state_pattern ret_v
- else mk_result_fail_rvalue ctx.ret_ty
+ else mk_result_fail_texpression ctx.ret_ty
and translate_return (config : config) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
@@ -1051,12 +1051,12 @@ and translate_return (config : config) (opt_v : V.typed_value option)
in
let state_rvalue = mk_texpression_from_var state_var in
let ret_v =
- mk_result_return_rvalue
+ mk_result_return_texpression
(mk_simpl_tuple_texpression [ state_rvalue; v ])
in
let state_var = mk_typed_pattern_from_var state_var None in
mk_abs state_var ret_v
- else mk_result_return_rvalue v
+ else mk_result_return_texpression v
| Some bid ->
(* Backward function *)
(* Sanity check *)
@@ -1075,14 +1075,14 @@ and translate_return (config : config) (opt_v : V.typed_value option)
let state_rvalue = mk_texpression_from_var state_var in
let ret_value = mk_simpl_tuple_texpression field_values in
let ret_value =
- mk_result_return_rvalue
+ mk_result_return_texpression
(mk_simpl_tuple_texpression [ state_rvalue; ret_value ])
in
let state_var = mk_typed_pattern_from_var state_var None in
mk_abs state_var ret_value
else
let ret_value = mk_simpl_tuple_texpression field_values in
- let ret_value = mk_result_return_rvalue ret_value in
+ let ret_value = mk_result_return_texpression ret_value in
ret_value
and translate_function_call (config : config) (call : S.call) (e : S.expression)