summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-04-27 15:47:39 +0200
committerSon Ho2022-04-27 15:47:39 +0200
commit018278ff418da62d1391c5f500def96890602f5a (patch)
treef19b69d79998e1022ad898d4dfc4e319f058f95e /src
parent003d039b5b51619699e96669007f6d095928251c (diff)
Fix various bugs when extracting with a state monad
Diffstat (limited to 'src')
-rw-r--r--src/ExtractToFStar.ml22
-rw-r--r--src/Pure.ml15
-rw-r--r--src/PureMicroPasses.ml78
-rw-r--r--src/PureUtils.ml9
-rw-r--r--src/SymbolicToPure.ml74
5 files changed, 146 insertions, 52 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 5e65d560..f0ec73f5 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -942,7 +942,14 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
F.pp_close_box fmt ();
(* Return *)
if inside then F.pp_print_string fmt ")"
- | _ -> raise (Failure "Unreachable"))
+ | _ ->
+ raise
+ (Failure
+ ("Unreachable:\n" ^ "Function: " ^ show_fun_id func.func
+ ^ ",\nNumber of arguments: "
+ ^ string_of_int (List.length args)
+ ^ ",\nArguments: "
+ ^ String.concat " " (List.map show_texpression args))))
| _ ->
(* "Regular" expression *)
(* Open parentheses *)
@@ -950,7 +957,7 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(* Open a box for the application *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* Print the app expression *)
- let app_inside = inside && args <> [] in
+ let app_inside = (inside && args = []) || args <> [] in
extract_texpression ctx fmt app_inside app;
(* Print the arguments *)
List.iter
@@ -993,10 +1000,10 @@ and extract_Let (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(next_e : texpression) : unit =
(* Open a box for the whole expression *)
F.pp_open_hvbox fmt 0;
- (* Open a box for the let-binding *)
- F.pp_open_hovbox fmt ctx.indent_incr;
(* Open parentheses *)
if inside then F.pp_print_string fmt "(";
+ (* Open a box for the let-binding *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
let ctx =
if monadic then (
(* Note that in F*, the left value of a monadic let-binding can only be
@@ -1020,13 +1027,13 @@ and extract_Let (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
F.pp_print_string fmt "in";
ctx)
in
- (* Close parentheses *)
- if inside then F.pp_print_string fmt ")";
(* Close the box for the let-binding *)
F.pp_close_box fmt ();
(* Print the next expression *)
F.pp_print_space fmt ();
extract_texpression ctx fmt false next_e;
+ (* Close parentheses *)
+ if inside then F.pp_print_string fmt ")";
(* Close the box for the whole expression *)
F.pp_close_box fmt ()
@@ -1361,7 +1368,8 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
* Tot (result (state & u32)) (decreases (f_decreases x st))
* ```
* Rk.: if a function has a decreases clause, it is necessarily
- * a transparent function *)
+ * a transparent function
+ *)
let inputs_lvs =
let num_fwd_inputs = List.length (Option.get fwd_def.body).inputs_lvs in
let num_fwd_inputs =
diff --git a/src/Pure.ml b/src/Pure.ml
index 6729a43d..cd28b035 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -171,6 +171,8 @@ type mplace = {
(* TODO: there shouldn't be places *)
type place = { var : VarId.id; projection : projection } [@@deriving show]
+type variant_id = VariantId.id [@@deriving show]
+
(** Ancestor for [iter_var_or_dummy] visitor *)
class ['self] iter_value_base =
object (_self : 'self)
@@ -185,6 +187,8 @@ class ['self] iter_value_base =
method visit_mplace : 'env -> mplace -> unit = fun _ _ -> ()
method visit_ty : 'env -> ty -> unit = fun _ _ -> ()
+
+ method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> ()
end
(** Ancestor for [map_typed_rvalue] visitor *)
@@ -202,6 +206,8 @@ class ['self] map_value_base =
method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x
method visit_ty : 'env -> ty -> ty = fun _ x -> x
+
+ method visit_variant_id : 'env -> variant_id -> variant_id = fun _ x -> x
end
(** Ancestor for [reduce_typed_rvalue] visitor *)
@@ -219,6 +225,8 @@ class virtual ['self] reduce_value_base =
method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero
method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero
+
+ method visit_variant_id : 'env -> variant_id -> 'a = fun _ _ -> self#zero
end
(** Ancestor for [mapreduce_typed_rvalue] visitor *)
@@ -238,6 +246,9 @@ class virtual ['self] mapreduce_value_base =
fun _ x -> (x, self#zero)
method visit_ty : 'env -> ty -> ty * 'a = fun _ x -> (x, self#zero)
+
+ method visit_variant_id : 'env -> variant_id -> variant_id * 'a =
+ fun _ x -> (x, self#zero)
end
(* TODO: merge with expressions *)
@@ -247,7 +258,7 @@ type rvalue =
| RvAdt of adt_rvalue
and adt_rvalue = {
- variant_id : (VariantId.id option[@opaque]);
+ variant_id : variant_id option;
(* TODO: variant constructors should be expressions, treated in a manner
* similar to functions *)
field_values : typed_rvalue list;
@@ -348,7 +359,7 @@ type lvalue =
| LvAdt of adt_lvalue
and adt_lvalue = {
- variant_id : (VariantId.id option[@opaque]);
+ variant_id : variant_id option;
field_values : typed_lvalue list;
}
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 2c4c667f..e22043e3 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -961,7 +961,8 @@ let filter_if_backward_with_no_outputs (config : config) (def : fun_decl) :
fun_decl option =
let return_ty =
if config.use_state_monad then
- mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ])
+ mk_arrow mk_state_ty
+ (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ]))
else mk_result_ty (mk_simpl_tuple_ty [ unit_ty ])
in
if
@@ -1146,7 +1147,7 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) :
{ def with body }
(** Unfold the monadic let-bindings to explicit matches. *)
-let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
+let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
(def : fun_decl) : fun_decl =
match def.body with
| None -> def
@@ -1169,13 +1170,13 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
method! visit_Let env monadic lv re e =
(* For now, we do the following transformation:
* ```
- * x <-- e1; e2
+ * x <-- re; e
*
* ~~>
*
* (fun st ->
- * match e1 st with
- * | Return (st', x) -> e2 st'
+ * match re st with
+ * | Return (st', x) -> e st'
* | Fail err -> Fail err)
* ```
*
@@ -1204,8 +1205,14 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
Option.is_some (opt_destruct_state_monad_result re.ty)
in
if re_uses_state then (
+ let e0 = e in
(* Create a fresh state variable *)
let state_var = fresh_state_var () in
+ (* The type of `e` is: `state -> e_no_arrow_ty` *)
+ let _, e_no_arrow_ty = destruct_arrow e.ty in
+ let e_no_monad_ty = destruct_result e_no_arrow_ty in
+ let _, re_no_arrow_ty = destruct_arrow re.ty in
+ let re_no_monad_ty = destruct_result re_no_arrow_ty in
(* Add the state argument on the right-expression *)
let re =
let state_value = mk_typed_rvalue_from_var state_var in
@@ -1213,8 +1220,8 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
mk_app re state_value
in
(* Create the match *)
- let fail_pat = mk_result_fail_lvalue lv.ty in
- let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_pat = mk_result_fail_lvalue re_no_monad_ty in
+ let fail_value = mk_result_fail_rvalue e_no_monad_ty in
let fail_branch =
{
pat = fail_pat;
@@ -1222,23 +1229,31 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
}
in
(* The `Success` branch introduces a fresh state variable *)
- let state_var = fresh_state_var () in
- let state_value = mk_typed_lvalue_from_var state_var None in
+ let pat_state_var = fresh_state_var () in
+ let pat_state_lvalue =
+ mk_typed_lvalue_from_var pat_state_var None
+ in
let success_pat =
mk_result_return_lvalue
- (mk_simpl_tuple_lvalue [ state_value; lv ])
+ (mk_simpl_tuple_lvalue [ pat_state_lvalue; lv ])
+ in
+ let pat_state_rvalue = mk_typed_rvalue_from_var pat_state_var in
+ let pat_state_rvalue =
+ mk_value_expression pat_state_rvalue None
in
(* TODO: write a utility to create matches (and perform
* type-checking, etc.) *)
- let ty = e.ty in
- let success_branch = { pat = success_pat; branch = e } in
+ let success_branch =
+ { pat = success_pat; branch = mk_app e pat_state_rvalue }
+ in
let switch_body = Match [ fail_branch; success_branch ] in
let e = Switch (re, switch_body) in
- let e = { e; ty } in
- (* Sanity check *)
- assert (ty = fail_value.ty);
+ let e = { e; ty = e_no_arrow_ty } in
(* Add the lambda to introduce the state variable *)
let e = mk_abs (mk_typed_lvalue_from_var state_var None) e in
+ (* Sanity check *)
+ assert (e0.ty = e.ty);
+ assert (fail_branch.branch.ty = success_branch.branch.ty);
(* Continue *)
self#visit_expression env e.e)
else
@@ -1256,14 +1271,39 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
let success_branch = { pat = success_pat; branch = e } in
let switch_body = Match [ fail_branch; success_branch ] in
let e = Switch (re, switch_body) in
+ (* Continue *)
self#visit_expression env e
end
in
- (* Update the body *)
- let body_e = obj#visit_texpression () body.body in
- let body = { body with body = body_e } in
+ (* Update the body: add *)
+ let body, signature =
+ let state_var = fresh_state_var () in
+ (* First, unfold the expressions inside the body *)
+ let body_e = obj#visit_texpression () body.body in
+ (* Then, add a "state" input variable if necessary: *)
+ if config.use_state_monad then
+ (* - in the body *)
+ let state_rvalue = mk_typed_rvalue_from_var state_var in
+ let body_e = mk_app body_e (mk_value_expression state_rvalue None) in
+ (* - in the signature *)
+ let sg = def.signature in
+ (* Input types *)
+ let sg_inputs = sg.inputs @ [ mk_state_ty ] in
+ (* Output types *)
+ let sg_outputs = Collections.List.to_cons_nil sg.outputs in
+ let _, sg_outputs = dest_arrow_ty sg_outputs in
+ let sg_outputs = [ sg_outputs ] in
+ let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in
+ (* Input list *)
+ let inputs = body.inputs @ [ state_var ] in
+ let input_lv = mk_typed_lvalue_from_var state_var None in
+ let inputs_lvs = body.inputs_lvs @ [ input_lv ] in
+ let body = { body = body_e; inputs; inputs_lvs } in
+ (body, sg)
+ else ({ body with body = body_e }, def.signature)
+ in
(* Return *)
- { def with body = Some body }
+ { def with body = Some body; signature }
(** Apply all the micro-passes to a function.
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index bcf93c3c..b87a6346 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -435,6 +435,8 @@ let opt_destruct_result (ty : ty) : ty option =
| Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys)
| _ -> None
+let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty)
+
let opt_destruct_tuple (ty : ty) : ty list option =
match ty with Adt (Tuple, tys) -> Some tys | _ -> None
@@ -469,3 +471,10 @@ let rec destruct_abs_list (e : texpression) : typed_lvalue list * texpression =
let xl, e'' = destruct_abs_list e' in
(x :: xl, e'')
| _ -> ([], e)
+
+let destruct_arrow (ty : ty) : ty * ty =
+ match ty with
+ | Arrow (ty0, ty1) -> (ty0, ty1)
+ | _ -> raise (Failure "Unreachable")
+
+let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1)
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 4e15d921..49bf3559 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -948,17 +948,26 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list =
(** Small utility.
- Return true if a function return type is monadic.
- Always true, at the exception of some assumed functions.
+ Return: (function is monadic, function uses state monad)
+
+ Note that all functions are monadic except some assumed functions.
*)
-let fun_is_monadic (fun_id : A.fun_id) : bool =
+let fun_is_monadic (fun_id : A.fun_id) : bool * bool =
match fun_id with
- | A.Regular _ -> true
- | A.Assumed aid -> Assumed.assumed_is_monadic aid
+ | A.Regular _ -> (true, true)
+ | A.Assumed aid -> (Assumed.assumed_is_monadic aid, false)
+
+(** Utility for function return types.
-let mk_function_ret_ty (config : config) (monadic : bool) (out_ty : ty) : ty =
+ A function return type can have the shape:
+ - ty
+ - result ty (* error-monad *)
+ - state -> result (state & ty) (* state-error monad *)
+ *)
+let mk_function_ret_ty (config : config) (monadic : bool) (state_monad : bool)
+ (out_ty : ty) : ty =
if monadic then
- if config.use_state_monad then
+ if config.use_state_monad && state_monad then
let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
let ret = mk_arrow_ty mk_state_ty ret in
ret
@@ -969,19 +978,34 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
: texpression =
match e with
| S.Return opt_v -> translate_return config opt_v ctx
- | Panic ->
- (* Here we use the function return type - note that it is ok because
- * we don't match on panics which happen inside the function body -
- * but it won't be true anymore once we translate individual blocks *)
- let v = mk_result_fail_rvalue ctx.ret_ty in
- let e = Value (v, None) in
- let ty = v.ty in
- { e; ty }
+ | Panic -> translate_panic config ctx
| FunCall (call, e) -> translate_function_call config call e ctx
| EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx
| Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx
| Meta (meta, e) -> translate_meta config meta e ctx
+and translate_panic (config : config) (ctx : bs_ctx) : texpression =
+ (* Here we use the function return type - note that it is ok because
+ * we don't match on panics which happen inside the function body -
+ * but it won't be true anymore once we translate individual blocks *)
+ (* If we use a state monad, we need to add a lambda for the state variable *)
+ if config.use_state_monad then
+ (* Create the `Fail` value *)
+ let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.ret_ty ] in
+ let v = mk_result_fail_rvalue ret_ty in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ let e = { e; ty } in
+ (* Add the lambda *)
+ let _, state_var = fresh_var (Some "st") mk_state_ty ctx in
+ let state_lvalue = mk_typed_lvalue_from_var state_var None in
+ mk_abs state_lvalue e
+ else
+ let v = mk_result_fail_rvalue ctx.ret_ty in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ { e; ty }
+
and translate_return (config : config) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
@@ -1058,21 +1082,21 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
* if necessary. *)
- let ctx, func, monadic =
+ let ctx, func, monadic, state_monad =
match call.call_id with
| S.Fun (fid, call_id) ->
let ctx = bs_ctx_register_forward_call call_id call ctx in
let func = Regular (fid, None) in
- let monadic = fun_is_monadic fid in
- (ctx, func, monadic)
- | S.Unop E.Not -> (ctx, Unop Not, false)
+ let monadic, state_monad = fun_is_monadic fid in
+ (ctx, func, monadic, state_monad)
+ | S.Unop E.Not -> (ctx, Unop Not, false, false)
| S.Unop E.Neg -> (
match args with
| [ arg ] ->
let int_ty = ty_as_integer arg.ty in
(* Note that negation can lead to an overflow and thus fail (it
* is thus monadic) *)
- (ctx, Unop (Neg int_ty), true)
+ (ctx, Unop (Neg int_ty), true, false)
| _ -> failwith "Unreachable")
| S.Binop binop -> (
match args with
@@ -1081,7 +1105,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let int_ty1 = ty_as_integer arg1.ty in
assert (int_ty0 = int_ty1);
let monadic = binop_can_fail binop in
- (ctx, Binop (binop, int_ty0), monadic)
+ (ctx, Binop (binop, int_ty0), monadic, false)
| _ -> failwith "Unreachable")
in
let args =
@@ -1092,7 +1116,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let dest_v = mk_typed_lvalue_from_var dest dest_mplace in
let func = { func; type_params } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = mk_function_ret_ty config monadic dest_v.ty in
+ let ret_ty = mk_function_ret_ty config monadic state_monad dest_v.ty in
let func_ty = mk_arrows input_tys ret_ty in
let func = { e = Func func; ty = func_ty } in
let call = mk_apps func args in
@@ -1223,9 +1247,9 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
(fun (arg, mp) -> mk_value_expression arg mp)
(List.combine inputs args_mplaces)
in
- let monadic = fun_is_monadic fun_id in
+ let monadic, state_monad = fun_is_monadic fun_id in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = mk_function_ret_ty config monadic output.ty in
+ let ret_ty = mk_function_ret_ty config monadic state_monad output.ty in
let func_ty = mk_arrows input_tys ret_ty in
let func = { func; type_params } in
let func = { e = Func func; ty = func_ty } in
@@ -1444,7 +1468,9 @@ and translate_expansion (config : config) (p : S.mplace option)
(* There should be at least one branch *)
let branch = List.hd branches in
let ty = branch.branch.ty in
+ (* Sanity check *)
assert (List.for_all (fun br -> br.branch.ty = ty) branches);
+ (* Return *)
{ e; ty })
| ExpandBool (true_e, false_e) ->
(* We don't need to update the context: we don't introduce any