diff options
author | Son Ho | 2022-04-27 15:47:39 +0200 |
---|---|---|
committer | Son Ho | 2022-04-27 15:47:39 +0200 |
commit | 018278ff418da62d1391c5f500def96890602f5a (patch) | |
tree | f19b69d79998e1022ad898d4dfc4e319f058f95e /src | |
parent | 003d039b5b51619699e96669007f6d095928251c (diff) |
Fix various bugs when extracting with a state monad
Diffstat (limited to 'src')
-rw-r--r-- | src/ExtractToFStar.ml | 22 | ||||
-rw-r--r-- | src/Pure.ml | 15 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 78 | ||||
-rw-r--r-- | src/PureUtils.ml | 9 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 74 |
5 files changed, 146 insertions, 52 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 5e65d560..f0ec73f5 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -942,7 +942,14 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_close_box fmt (); (* Return *) if inside then F.pp_print_string fmt ")" - | _ -> raise (Failure "Unreachable")) + | _ -> + raise + (Failure + ("Unreachable:\n" ^ "Function: " ^ show_fun_id func.func + ^ ",\nNumber of arguments: " + ^ string_of_int (List.length args) + ^ ",\nArguments: " + ^ String.concat " " (List.map show_texpression args)))) | _ -> (* "Regular" expression *) (* Open parentheses *) @@ -950,7 +957,7 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Open a box for the application *) F.pp_open_hovbox fmt ctx.indent_incr; (* Print the app expression *) - let app_inside = inside && args <> [] in + let app_inside = (inside && args = []) || args <> [] in extract_texpression ctx fmt app_inside app; (* Print the arguments *) List.iter @@ -993,10 +1000,10 @@ and extract_Let (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (next_e : texpression) : unit = (* Open a box for the whole expression *) F.pp_open_hvbox fmt 0; - (* Open a box for the let-binding *) - F.pp_open_hovbox fmt ctx.indent_incr; (* Open parentheses *) if inside then F.pp_print_string fmt "("; + (* Open a box for the let-binding *) + F.pp_open_hovbox fmt ctx.indent_incr; let ctx = if monadic then ( (* Note that in F*, the left value of a monadic let-binding can only be @@ -1020,13 +1027,13 @@ and extract_Let (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_print_string fmt "in"; ctx) in - (* Close parentheses *) - if inside then F.pp_print_string fmt ")"; (* Close the box for the let-binding *) F.pp_close_box fmt (); (* Print the next expression *) F.pp_print_space fmt (); extract_texpression ctx fmt false next_e; + (* Close parentheses *) + if inside then F.pp_print_string fmt ")"; (* Close the box for the whole expression *) F.pp_close_box fmt () @@ -1361,7 +1368,8 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) * Tot (result (state & u32)) (decreases (f_decreases x st)) * ``` * Rk.: if a function has a decreases clause, it is necessarily - * a transparent function *) + * a transparent function + *) let inputs_lvs = let num_fwd_inputs = List.length (Option.get fwd_def.body).inputs_lvs in let num_fwd_inputs = diff --git a/src/Pure.ml b/src/Pure.ml index 6729a43d..cd28b035 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -171,6 +171,8 @@ type mplace = { (* TODO: there shouldn't be places *) type place = { var : VarId.id; projection : projection } [@@deriving show] +type variant_id = VariantId.id [@@deriving show] + (** Ancestor for [iter_var_or_dummy] visitor *) class ['self] iter_value_base = object (_self : 'self) @@ -185,6 +187,8 @@ class ['self] iter_value_base = method visit_mplace : 'env -> mplace -> unit = fun _ _ -> () method visit_ty : 'env -> ty -> unit = fun _ _ -> () + + method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> () end (** Ancestor for [map_typed_rvalue] visitor *) @@ -202,6 +206,8 @@ class ['self] map_value_base = method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x method visit_ty : 'env -> ty -> ty = fun _ x -> x + + method visit_variant_id : 'env -> variant_id -> variant_id = fun _ x -> x end (** Ancestor for [reduce_typed_rvalue] visitor *) @@ -219,6 +225,8 @@ class virtual ['self] reduce_value_base = method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero + + method visit_variant_id : 'env -> variant_id -> 'a = fun _ _ -> self#zero end (** Ancestor for [mapreduce_typed_rvalue] visitor *) @@ -238,6 +246,9 @@ class virtual ['self] mapreduce_value_base = fun _ x -> (x, self#zero) method visit_ty : 'env -> ty -> ty * 'a = fun _ x -> (x, self#zero) + + method visit_variant_id : 'env -> variant_id -> variant_id * 'a = + fun _ x -> (x, self#zero) end (* TODO: merge with expressions *) @@ -247,7 +258,7 @@ type rvalue = | RvAdt of adt_rvalue and adt_rvalue = { - variant_id : (VariantId.id option[@opaque]); + variant_id : variant_id option; (* TODO: variant constructors should be expressions, treated in a manner * similar to functions *) field_values : typed_rvalue list; @@ -348,7 +359,7 @@ type lvalue = | LvAdt of adt_lvalue and adt_lvalue = { - variant_id : (VariantId.id option[@opaque]); + variant_id : variant_id option; field_values : typed_lvalue list; } diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 2c4c667f..e22043e3 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -961,7 +961,8 @@ let filter_if_backward_with_no_outputs (config : config) (def : fun_decl) : fun_decl option = let return_ty = if config.use_state_monad then - mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ]) + mk_arrow mk_state_ty + (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ])) else mk_result_ty (mk_simpl_tuple_ty [ unit_ty ]) in if @@ -1146,7 +1147,7 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : { def with body } (** Unfold the monadic let-bindings to explicit matches. *) -let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) +let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match def.body with | None -> def @@ -1169,13 +1170,13 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) method! visit_Let env monadic lv re e = (* For now, we do the following transformation: * ``` - * x <-- e1; e2 + * x <-- re; e * * ~~> * * (fun st -> - * match e1 st with - * | Return (st', x) -> e2 st' + * match re st with + * | Return (st', x) -> e st' * | Fail err -> Fail err) * ``` * @@ -1204,8 +1205,14 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) Option.is_some (opt_destruct_state_monad_result re.ty) in if re_uses_state then ( + let e0 = e in (* Create a fresh state variable *) let state_var = fresh_state_var () in + (* The type of `e` is: `state -> e_no_arrow_ty` *) + let _, e_no_arrow_ty = destruct_arrow e.ty in + let e_no_monad_ty = destruct_result e_no_arrow_ty in + let _, re_no_arrow_ty = destruct_arrow re.ty in + let re_no_monad_ty = destruct_result re_no_arrow_ty in (* Add the state argument on the right-expression *) let re = let state_value = mk_typed_rvalue_from_var state_var in @@ -1213,8 +1220,8 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) mk_app re state_value in (* Create the match *) - let fail_pat = mk_result_fail_lvalue lv.ty in - let fail_value = mk_result_fail_rvalue e.ty in + let fail_pat = mk_result_fail_lvalue re_no_monad_ty in + let fail_value = mk_result_fail_rvalue e_no_monad_ty in let fail_branch = { pat = fail_pat; @@ -1222,23 +1229,31 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) } in (* The `Success` branch introduces a fresh state variable *) - let state_var = fresh_state_var () in - let state_value = mk_typed_lvalue_from_var state_var None in + let pat_state_var = fresh_state_var () in + let pat_state_lvalue = + mk_typed_lvalue_from_var pat_state_var None + in let success_pat = mk_result_return_lvalue - (mk_simpl_tuple_lvalue [ state_value; lv ]) + (mk_simpl_tuple_lvalue [ pat_state_lvalue; lv ]) + in + let pat_state_rvalue = mk_typed_rvalue_from_var pat_state_var in + let pat_state_rvalue = + mk_value_expression pat_state_rvalue None in (* TODO: write a utility to create matches (and perform * type-checking, etc.) *) - let ty = e.ty in - let success_branch = { pat = success_pat; branch = e } in + let success_branch = + { pat = success_pat; branch = mk_app e pat_state_rvalue } + in let switch_body = Match [ fail_branch; success_branch ] in let e = Switch (re, switch_body) in - let e = { e; ty } in - (* Sanity check *) - assert (ty = fail_value.ty); + let e = { e; ty = e_no_arrow_ty } in (* Add the lambda to introduce the state variable *) let e = mk_abs (mk_typed_lvalue_from_var state_var None) e in + (* Sanity check *) + assert (e0.ty = e.ty); + assert (fail_branch.branch.ty = success_branch.branch.ty); (* Continue *) self#visit_expression env e.e) else @@ -1256,14 +1271,39 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) let success_branch = { pat = success_pat; branch = e } in let switch_body = Match [ fail_branch; success_branch ] in let e = Switch (re, switch_body) in + (* Continue *) self#visit_expression env e end in - (* Update the body *) - let body_e = obj#visit_texpression () body.body in - let body = { body with body = body_e } in + (* Update the body: add *) + let body, signature = + let state_var = fresh_state_var () in + (* First, unfold the expressions inside the body *) + let body_e = obj#visit_texpression () body.body in + (* Then, add a "state" input variable if necessary: *) + if config.use_state_monad then + (* - in the body *) + let state_rvalue = mk_typed_rvalue_from_var state_var in + let body_e = mk_app body_e (mk_value_expression state_rvalue None) in + (* - in the signature *) + let sg = def.signature in + (* Input types *) + let sg_inputs = sg.inputs @ [ mk_state_ty ] in + (* Output types *) + let sg_outputs = Collections.List.to_cons_nil sg.outputs in + let _, sg_outputs = dest_arrow_ty sg_outputs in + let sg_outputs = [ sg_outputs ] in + let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in + (* Input list *) + let inputs = body.inputs @ [ state_var ] in + let input_lv = mk_typed_lvalue_from_var state_var None in + let inputs_lvs = body.inputs_lvs @ [ input_lv ] in + let body = { body = body_e; inputs; inputs_lvs } in + (body, sg) + else ({ body with body = body_e }, def.signature) + in (* Return *) - { def with body = Some body } + { def with body = Some body; signature } (** Apply all the micro-passes to a function. diff --git a/src/PureUtils.ml b/src/PureUtils.ml index bcf93c3c..b87a6346 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -435,6 +435,8 @@ let opt_destruct_result (ty : ty) : ty option = | Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys) | _ -> None +let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty) + let opt_destruct_tuple (ty : ty) : ty list option = match ty with Adt (Tuple, tys) -> Some tys | _ -> None @@ -469,3 +471,10 @@ let rec destruct_abs_list (e : texpression) : typed_lvalue list * texpression = let xl, e'' = destruct_abs_list e' in (x :: xl, e'') | _ -> ([], e) + +let destruct_arrow (ty : ty) : ty * ty = + match ty with + | Arrow (ty0, ty1) -> (ty0, ty1) + | _ -> raise (Failure "Unreachable") + +let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1) diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 4e15d921..49bf3559 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -948,17 +948,26 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list = (** Small utility. - Return true if a function return type is monadic. - Always true, at the exception of some assumed functions. + Return: (function is monadic, function uses state monad) + + Note that all functions are monadic except some assumed functions. *) -let fun_is_monadic (fun_id : A.fun_id) : bool = +let fun_is_monadic (fun_id : A.fun_id) : bool * bool = match fun_id with - | A.Regular _ -> true - | A.Assumed aid -> Assumed.assumed_is_monadic aid + | A.Regular _ -> (true, true) + | A.Assumed aid -> (Assumed.assumed_is_monadic aid, false) + +(** Utility for function return types. -let mk_function_ret_ty (config : config) (monadic : bool) (out_ty : ty) : ty = + A function return type can have the shape: + - ty + - result ty (* error-monad *) + - state -> result (state & ty) (* state-error monad *) + *) +let mk_function_ret_ty (config : config) (monadic : bool) (state_monad : bool) + (out_ty : ty) : ty = if monadic then - if config.use_state_monad then + if config.use_state_monad && state_monad then let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in let ret = mk_arrow_ty mk_state_ty ret in ret @@ -969,19 +978,34 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) : texpression = match e with | S.Return opt_v -> translate_return config opt_v ctx - | Panic -> - (* Here we use the function return type - note that it is ok because - * we don't match on panics which happen inside the function body - - * but it won't be true anymore once we translate individual blocks *) - let v = mk_result_fail_rvalue ctx.ret_ty in - let e = Value (v, None) in - let ty = v.ty in - { e; ty } + | Panic -> translate_panic config ctx | FunCall (call, e) -> translate_function_call config call e ctx | EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx | Meta (meta, e) -> translate_meta config meta e ctx +and translate_panic (config : config) (ctx : bs_ctx) : texpression = + (* Here we use the function return type - note that it is ok because + * we don't match on panics which happen inside the function body - + * but it won't be true anymore once we translate individual blocks *) + (* If we use a state monad, we need to add a lambda for the state variable *) + if config.use_state_monad then + (* Create the `Fail` value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.ret_ty ] in + let v = mk_result_fail_rvalue ret_ty in + let e = Value (v, None) in + let ty = v.ty in + let e = { e; ty } in + (* Add the lambda *) + let _, state_var = fresh_var (Some "st") mk_state_ty ctx in + let state_lvalue = mk_typed_lvalue_from_var state_var None in + mk_abs state_lvalue e + else + let v = mk_result_fail_rvalue ctx.ret_ty in + let e = Value (v, None) in + let ty = v.ty in + { e; ty } + and translate_return (config : config) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -1058,21 +1082,21 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, func, monadic = + let ctx, func, monadic, state_monad = match call.call_id with | S.Fun (fid, call_id) -> let ctx = bs_ctx_register_forward_call call_id call ctx in let func = Regular (fid, None) in - let monadic = fun_is_monadic fid in - (ctx, func, monadic) - | S.Unop E.Not -> (ctx, Unop Not, false) + let monadic, state_monad = fun_is_monadic fid in + (ctx, func, monadic, state_monad) + | S.Unop E.Not -> (ctx, Unop Not, false, false) | S.Unop E.Neg -> ( match args with | [ arg ] -> let int_ty = ty_as_integer arg.ty in (* Note that negation can lead to an overflow and thus fail (it * is thus monadic) *) - (ctx, Unop (Neg int_ty), true) + (ctx, Unop (Neg int_ty), true, false) | _ -> failwith "Unreachable") | S.Binop binop -> ( match args with @@ -1081,7 +1105,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let int_ty1 = ty_as_integer arg1.ty in assert (int_ty0 = int_ty1); let monadic = binop_can_fail binop in - (ctx, Binop (binop, int_ty0), monadic) + (ctx, Binop (binop, int_ty0), monadic, false) | _ -> failwith "Unreachable") in let args = @@ -1092,7 +1116,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let dest_v = mk_typed_lvalue_from_var dest dest_mplace in let func = { func; type_params } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = mk_function_ret_ty config monadic dest_v.ty in + let ret_ty = mk_function_ret_ty config monadic state_monad dest_v.ty in let func_ty = mk_arrows input_tys ret_ty in let func = { e = Func func; ty = func_ty } in let call = mk_apps func args in @@ -1223,9 +1247,9 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (fun (arg, mp) -> mk_value_expression arg mp) (List.combine inputs args_mplaces) in - let monadic = fun_is_monadic fun_id in + let monadic, state_monad = fun_is_monadic fun_id in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = mk_function_ret_ty config monadic output.ty in + let ret_ty = mk_function_ret_ty config monadic state_monad output.ty in let func_ty = mk_arrows input_tys ret_ty in let func = { func; type_params } in let func = { e = Func func; ty = func_ty } in @@ -1444,7 +1468,9 @@ and translate_expansion (config : config) (p : S.mplace option) (* There should be at least one branch *) let branch = List.hd branches in let ty = branch.branch.ty in + (* Sanity check *) assert (List.for_all (fun br -> br.branch.ty = ty) branches); + (* Return *) { e; ty }) | ExpandBool (true_e, false_e) -> (* We don't need to update the context: we don't introduce any |