summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r--compiler/SymbolicToPure.ml169
1 files changed, 112 insertions, 57 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 456ec0f6..d62cc829 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -127,7 +127,15 @@ type bs_ctx = {
trait_decls_ctx : trait_decls_context;
trait_impls_ctx : trait_impls_context;
fun_decl : A.fun_decl;
- bid : T.RegionGroupId.id option; (** TODO: rename *)
+ bid : RegionGroupId.id option;
+ (** TODO: rename
+
+ The id of the group region we are currently translating.
+ If we split the forward/backward functions, we set this id at the
+ very beginning of the translation.
+ If we don't split, we set it to `None`, then update it when we enter
+ an expression which is specific to a backward function.
+ *)
sg : decomposed_fun_sig;
(** Information about the function signature - useful in particular to
translate [Panic] *)
@@ -139,7 +147,7 @@ type bs_ctx = {
var_counter : VarId.generator;
state_var : VarId.id;
(** The current state variable, in case the function is stateful *)
- back_state_var : VarId.id;
+ back_state_vars : VarId.id RegionGroupId.Map.t;
(** The additional input state variable received by a stateful backward function.
When generating stateful functions, we generate code of the following
form:
@@ -163,16 +171,16 @@ type bs_ctx = {
(** The input parameters for the forward function corresponding to the
translated Rust inputs (no fuel, no state).
*)
- backward_inputs : var list T.RegionGroupId.Map.t;
+ backward_inputs : var list RegionGroupId.Map.t;
(** The additional input parameters for the backward functions coming
from the borrows consumed upon ending the lifetime (as a consequence
those don't include the backward state, if there is one).
*)
- backward_outputs : var list T.RegionGroupId.Map.t;
+ backward_outputs : var list RegionGroupId.Map.t;
(** The variables that the backward functions will output, corresponding
to the borrows they give back (don't include the backward state)
*)
- loop_backward_outputs : var list T.RegionGroupId.Map.t option;
+ loop_backward_outputs : var list RegionGroupId.Map.t option;
(** Same as {!backward_outputs}, but for loops (if we entered a loop).
[None] if we are not inside a loop, [Some] otherwise (and whatever
@@ -300,6 +308,13 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string =
let env = bs_ctx_to_pure_fmt_env ctx in
PrintPure.typed_pattern_to_string env p
+let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info =
+ match ctx.bid with
+ | None -> ctx.sg.fwd_info.effect_info
+ | Some bid ->
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ back_sg.effect_info
+
(* TODO: move *)
let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
let env = bs_ctx_to_fmt_env ctx in
@@ -1034,6 +1049,24 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
fwd_info;
}
+let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
+ =
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty
+ in
+ if effect_info.can_fail then mk_result_ty output else output
+
+(** Compute the arrow types for all the backward functions *)
+let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list =
+ List.map
+ (fun (back_sg : back_sg_info) ->
+ let effect_info = back_sg.effect_info in
+ let inputs = dsg.fwd_inputs @ back_sg.inputs in
+ let output = mk_simpl_tuple_ty back_sg.outputs in
+ let output = mk_output_ty_from_effect_info effect_info output in
+ mk_arrows inputs output)
+ (RegionGroupId.Map.values dsg.back_sg)
+
let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid : RegionGroupId.id option) : fun_sig =
let generics = dsg.generics in
@@ -1050,27 +1083,13 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
in
(* Two cases depending on whether we split the forward/backward functions
or not *)
- let mk_output_ty (effect_info : fun_effect_info) output =
- let output =
- if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- if effect_info.can_fail then mk_result_ty output else output
- in
+ let mk_output_ty = mk_output_ty_from_effect_info in
+
let inputs, output =
if !Config.return_back_funs then (
assert (gid = None);
(* Compute the arrow types for all the backward functions *)
- let back_tys =
- List.map
- (fun (back_sg : back_sg_info) ->
- let effect_info = back_sg.effect_info in
- let inputs = dsg.fwd_inputs @ back_sg.inputs in
- let output = mk_simpl_tuple_ty back_sg.outputs in
- let output = mk_output_ty effect_info output in
- mk_arrows inputs output)
- (RegionGroupId.Map.values dsg.back_sg)
- in
+ let back_tys = compute_back_tys dsg in
(* Group the forward output and the types of the backward functions *)
let effect_info = dsg.fwd_info.effect_info in
let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in
@@ -1584,30 +1603,43 @@ and translate_panic (ctx : bs_ctx) : texpression =
* but it won't be true anymore once we translate individual blocks *)
(* If we use a state monad, we need to add a lambda for the state variable *)
(* Note that only forward functions return a state *)
- let output_ty =
- if ctx.inside_loop && Option.is_some ctx.bid then
- (* We are synthesizing the backward function of a loop body *)
- let bid = Option.get ctx.bid in
- let back_vars =
- T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs)
- in
- let tys = List.map (fun (v : var) -> v.ty) back_vars in
- mk_simpl_tuple_ty tys
- else
- (* Regular function, or forward function (the forward translation for
- a loop has the same return type as the parent function)
- *)
- mk_simpl_tuple_ty ctx.sg.doutputs
- in
+ let effect_info = ctx_get_effect_info ctx in
(* TODO: we should use a [Fail] function *)
- if ctx.sg.info.effect_info.stateful then
- (* Create the [Fail] value *)
- let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
- let ret_v =
- mk_result_fail_texpression_with_error_id error_failure_id ret_ty
+ let mk_output output_ty =
+ if effect_info.stateful then
+ (* Create the [Fail] value *)
+ let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
+ let ret_v =
+ mk_result_fail_texpression_with_error_id error_failure_id ret_ty
+ in
+ ret_v
+ else mk_result_fail_texpression_with_error_id error_failure_id output_ty
+ in
+ if ctx.inside_loop && Option.is_some ctx.bid then
+ (* We are synthesizing the backward function of a loop body *)
+ let bid = Option.get ctx.bid in
+ let back_vars =
+ T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs)
in
- ret_v
- else mk_result_fail_texpression_with_error_id error_failure_id output_ty
+ let tys = List.map (fun (v : var) -> v.ty) back_vars in
+ let output = mk_simpl_tuple_ty tys in
+ mk_output output
+ else
+ (* Regular function, or forward function (the forward translation for
+ a loop has the same return type as the parent function)
+ *)
+ match ctx.bid with
+ | None ->
+ if !Config.return_back_funs then
+ let back_tys = compute_back_tys ctx.sg in
+ let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in
+ mk_output output
+ else mk_output ctx.sg.fwd_output
+ | Some bid ->
+ let output =
+ mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs
+ in
+ mk_output output
(** [opt_v]: the value to return, in case we translate a forward body *)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
@@ -1641,7 +1673,7 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
* - error-monad: Return x
* - state-error: Return (state, x)
* *)
- let effect_info = ctx.sg.info.effect_info in
+ let effect_info = ctx_get_effect_info ctx in
let output =
if effect_info.stateful then
let state_rvalue = mk_state_texpression ctx.state_var in
@@ -1695,7 +1727,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
* effect - in particular, one manipulates a state iff the other does
* the same.
* *)
- let effect_info = ctx.sg.info.effect_info in
+ let effect_info = ctx_get_effect_info ctx in
let output =
if effect_info.stateful then
let state_rvalue = mk_state_texpression ctx.state_var in
@@ -2550,24 +2582,50 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
and translate_forward_end (ectx : C.eval_ctx)
(loop_input_values : V.typed_value S.symbolic_value_id_map option)
- (e : S.expression) (back_e : S.expression S.region_group_id_map)
+ (fwd_e : S.expression) (back_e : S.expression S.region_group_id_map)
(ctx : bs_ctx) : texpression =
- (* Update the current state with the additional state received by the backward
- function, if needs be, and lookup the proper expression *)
- let translate_end ctx =
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
+
+ let translate_one_end ctx (bid : RegionGroupId.id option) =
(* Update the current state with the additional state received by the backward
function, if needs be, and lookup the proper expression *)
let ctx, e =
match ctx.bid with
- | None -> (ctx, e)
+ | None ->
+ (* We are translating the forward function - nothing to do *)
+ (ctx, fwd_e)
| Some bid ->
- let ctx = { ctx with state_var = ctx.back_state_var } in
+ (* There are two cases here:
+ - if we split the fwd/backward functions, we simply need to update
+ the state
+ - if we don't split, we also need to wrap the expression in a
+ lambda, which introduces the additional inputs of the backward
+ function
+ *)
+ let back_state_var = RegionGroupId.Map.find bid ctx.back_state_vars in
+ let ctx = { ctx with state_var = back_state_var } in
let e = T.RegionGroupId.Map.find bid back_e in
(ctx, e)
in
translate_expression e ctx
in
+ (* There are two cases, depending on whether we are splitting the forward/backward
+ functions or not.
+
+ - if we split, then we simply need to translate the proper "end" expression,
+ that is the end of the forward function, or of the backward function we
+ are currently translating.
+ - if we don't split, then we need to translate the end of the forward
+ function (this is the value we will return) and generate the bodies
+ of the backward functions (which we will also return).
+
+ Update the current state with the additional state received by the backward
+ function, if needs be, and lookup the proper expression.
+ *)
+ let translate_end ctx = failwith "TODO" in
+
(* If we are (re-)entering a loop, we need to introduce a call to the
forward translation of the loop. *)
match loop_input_values with
@@ -2617,10 +2675,7 @@ and translate_forward_end (ectx : C.eval_ctx)
in
(* Introduce a fresh output value for the forward function *)
- let ctx, output_var =
- let output_ty = mk_simpl_tuple_ty ctx.fwd_sg.doutputs in
- fresh_var None output_ty ctx
- in
+ let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in
let args, ctx, out_pats =
let output_pat = mk_typed_pattern_from_var output_var None in
@@ -2832,7 +2887,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Add the input state *)
let input_state =
- if ctx.sg.info.effect_info.stateful then Some ctx.state_var else None
+ if (ctx_get_effect_info ctx).stateful then Some ctx.state_var else None
in
(* Translate the loop body *)