diff options
-rw-r--r-- | src/ExtractToFStar.ml | 4 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 19 | ||||
-rw-r--r-- | src/PureUtils.ml | 24 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 12 |
4 files changed, 31 insertions, 28 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 0a06ef73..068448e9 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -1475,8 +1475,8 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) let sg = def.signature in if sg.type_params = [] - && (sg.inputs = [ unit_ty ] || sg.inputs = []) - && sg.outputs = [ mk_result_ty unit_ty ] + && (sg.inputs = [ mk_unit_ty ] || sg.inputs = []) + && sg.outputs = [ mk_result_ty mk_unit_ty ] then ( (* Add a break before *) F.pp_print_break fmt 0 0; diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index bd7d7766..f76dd2f4 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -923,8 +923,8 @@ let filter_if_backward_with_no_outputs (config : config) (def : fun_decl) : let return_ty = if config.use_state_monad then mk_arrow mk_state_ty - (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ])) - else mk_result_ty (mk_simpl_tuple_ty [ unit_ty ]) + (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; mk_unit_ty ])) + else mk_result_ty (mk_simpl_tuple_ty [ mk_unit_ty ]) in if config.filter_useless_functions && Option.is_some def.back_id @@ -954,7 +954,7 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool = * they should be lists of length 1. *) if config.filter_useless_functions - && fwd.signature.outputs = [ mk_result_ty unit_ty ] + && fwd.signature.outputs = [ mk_result_ty mk_unit_ty ] && backs <> [] then false else true @@ -968,11 +968,12 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl = inherit [_] map_expression as super method! visit_PatVar _ v mp = - if v.ty = unit_ty then PatDummy else PatVar (v, mp) + if v.ty = mk_unit_ty then PatDummy else PatVar (v, mp) (** Replace in patterns *) method! visit_texpression env e = - if e.ty = unit_ty then unit_rvalue else super#visit_texpression env e + if e.ty = mk_unit_ty then mk_unit_rvalue + else super#visit_texpression env e (** Replace in "regular" expressions - note that we could limit ourselves to variables, but this is more powerful *) @@ -1026,7 +1027,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = | A.BoxDeref, Some _ -> (* `Box::deref` backward is `()` (doesn't give back anything) *) assert (args = []); - unit_rvalue + mk_unit_rvalue | A.BoxDerefMut, None -> (* `Box::deref_mut` forward is the identity *) let arg, args = Collections.List.pop args in @@ -1043,7 +1044,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = mk_apps arg args | A.BoxFree, _ -> assert (args = []); - unit_rvalue + mk_unit_rvalue | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen | A.VecIndex | A.VecIndexMut ), _ ) -> @@ -1215,7 +1216,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) in (* Create the match *) let fail_pat = mk_result_fail_pattern re_no_monad_ty in - let fail_value = mk_result_fail_rvalue e_no_monad_ty in + let fail_value = mk_result_fail_texpression e_no_monad_ty in let fail_branch = { pat = fail_pat; branch = fail_value } in (* The `Success` branch introduces a fresh state variable *) let pat_state_var = fresh_state_var () in @@ -1246,7 +1247,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) let re_ty = Option.get (opt_destruct_result re.ty) in assert (lv.ty = re_ty); let fail_pat = mk_result_fail_pattern lv.ty in - let fail_value = mk_result_fail_rvalue e.ty in + let fail_value = mk_result_fail_texpression e.ty in let fail_branch = { pat = fail_pat; branch = fail_value } in let success_pat = mk_result_return_pattern lv in let success_branch = { pat = success_pat; branch = e } in diff --git a/src/PureUtils.ml b/src/PureUtils.ml index fe71b3b2..73794a7c 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -323,9 +323,15 @@ let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) : | Match branches -> List.iter (fun (b : match_branch) -> f b.branch) branches let mk_switch (scrut : texpression) (sb : switch_body) : texpression = - (* TODO: check the type of the scrutinee *) - let ty = get_switch_body_ty sb in + (* Sanity check: the scrutinee has the proper type *) + (match sb with + | If (_, _) -> assert (scrut.ty = Bool) + | Match branches -> + List.iter + (fun (b : match_branch) -> assert (b.pat.ty = scrut.ty)) + branches); (* Sanity check: all the branches have the same type *) + let ty = get_switch_body_ty sb in iter_switch_body_branches (fun e -> assert (e.ty = ty)) sb; (* Put together *) let e = Switch (scrut, sb) in @@ -338,15 +344,13 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = let mk_simpl_tuple_ty (tys : ty list) : ty = match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys) -(** TODO: rename to "mk_..." *) -let unit_ty : ty = Adt (Tuple, []) +let mk_unit_ty : ty = Adt (Tuple, []) -(** TODO: rename to "mk_unit_texpression" *) -let unit_rvalue : texpression = +let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = Tuple; variant_id = None } in let qualif = { id; type_args = [] } in let e = Qualif qualif in - let ty = unit_ty in + let ty = mk_unit_ty in { e; ty } let mk_texpression_from_var (v : var) : texpression = @@ -416,8 +420,7 @@ let mk_state_ty : ty = Adt (Assumed State, []) let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) -(* TODO: rename *) -let mk_result_fail_rvalue (ty : ty) : texpression = +let mk_result_fail_texpression (ty : ty) : texpression = let type_args = [ ty ] in let ty = Adt (Assumed Result, type_args) in let id = @@ -429,8 +432,7 @@ let mk_result_fail_rvalue (ty : ty) : texpression = let cons = { e = cons_e; ty = cons_ty } in cons -(* TODO: rename *) -let mk_result_return_rvalue (v : texpression) : texpression = +let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in let ty = Adt (Assumed Result, type_args) in let id = diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index c0e47ca7..156d5661 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -1019,14 +1019,14 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression = if config.use_state_monad then (* Create the `Fail` value *) let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.ret_ty ] in - let ret_v = mk_result_fail_rvalue ret_ty in + let ret_v = mk_result_fail_texpression ret_ty in (* Add the lambda *) let _, state_var = fresh_var (Some ConstStrings.state_basename) mk_state_ty ctx in let state_pattern = mk_typed_pattern_from_var state_var None in mk_abs state_pattern ret_v - else mk_result_fail_rvalue ctx.ret_ty + else mk_result_fail_texpression ctx.ret_ty and translate_return (config : config) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = @@ -1051,12 +1051,12 @@ and translate_return (config : config) (opt_v : V.typed_value option) in let state_rvalue = mk_texpression_from_var state_var in let ret_v = - mk_result_return_rvalue + mk_result_return_texpression (mk_simpl_tuple_texpression [ state_rvalue; v ]) in let state_var = mk_typed_pattern_from_var state_var None in mk_abs state_var ret_v - else mk_result_return_rvalue v + else mk_result_return_texpression v | Some bid -> (* Backward function *) (* Sanity check *) @@ -1075,14 +1075,14 @@ and translate_return (config : config) (opt_v : V.typed_value option) let state_rvalue = mk_texpression_from_var state_var in let ret_value = mk_simpl_tuple_texpression field_values in let ret_value = - mk_result_return_rvalue + mk_result_return_texpression (mk_simpl_tuple_texpression [ state_rvalue; ret_value ]) in let state_var = mk_typed_pattern_from_var state_var None in mk_abs state_var ret_value else let ret_value = mk_simpl_tuple_texpression field_values in - let ret_value = mk_result_return_rvalue ret_value in + let ret_value = mk_result_return_texpression ret_value in ret_value and translate_function_call (config : config) (call : S.call) (e : S.expression) |