summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-19 12:54:40 +0100
committerSon Ho2023-12-19 12:54:40 +0100
commit4f7bc41dcbc6187512111a81f968726452024d25 (patch)
treebc78af79887a3165dcf5d7a837992b09cc6d3071
parent116b569d1b08a57c3ad66071979a1c966fdad3a2 (diff)
Simplify SymbolicToPure.bs_ctx.{backward_outputs, loop_backward_outputs}
-rw-r--r--compiler/SymbolicToPure.ml153
-rw-r--r--compiler/Translate.ml17
2 files changed, 70 insertions, 100 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index ea2082c7..93e6cb4e 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -109,6 +109,10 @@ type loop_info = {
(** The forward inputs are initialized at [None] *)
forward_output_no_state_no_result : var option;
(** The forward outputs are initialized at [None] *)
+ back_outputs : ty list RegionGroupId.Map.t;
+ (** The map from region group ids to the types of the values given back
+ by the corresponding loop abstractions.
+ *)
}
[@@deriving show]
@@ -187,12 +191,11 @@ type bs_ctx = {
Same remarks as for {!backward_inputs_no_state}.
*)
- backward_outputs : var list RegionGroupId.Map.t;
+ backward_outputs : var list option;
(** The variables that the backward functions will output, corresponding
to the borrows they give back (don't include the backward state).
The translation is done as follows:
- - for a given backward function, we choose a set of variables [v_i]
- when we detect the ended input abstraction which corresponds
to the backward function of the LLBC function we are translating,
and which consumed the values [consumed_i] (that we need to give
@@ -201,14 +204,20 @@ type bs_ctx = {
let v_i = consumed_i in
...
]}
- Then, upon reaching the [Return] node, we introduce:
+ where the [v_i] are fresh, and are stored in the [backward_output].
+ - Then, upon reaching the [Return] node, we introduce:
{[
- (v_i)
+ return (v_i)
]}
+
+ The option is [None] before we detect the ended input abstraction,
+ and [Some] afterwards.
*)
- loop_backward_outputs : var list RegionGroupId.Map.t option;
+ loop_backward_outputs : var list option;
(** Same as {!backward_outputs}, but for loops (if we entered a loop).
+ TODO: merge with [backward_outputs]?
+
[None] if we are not inside a loop, [Some] otherwise (and whatever
the kind of function we are translating: it will be [Some] even
though we are synthesizing a forward function).
@@ -1607,7 +1616,9 @@ let mk_emeta_symbolic_assignments (vars : var list) (values : texpression list)
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
- | S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx
+ | S.Return (ectx, opt_v) ->
+ (* Remark: we can't get there if we are inside a loop *)
+ translate_return ectx opt_v ctx
| ReturnWithLoop (loop_id, is_continue) ->
translate_return_with_loop loop_id is_continue ctx
| Panic -> translate_panic ctx
@@ -1644,10 +1655,9 @@ and translate_panic (ctx : bs_ctx) : texpression =
if ctx.inside_loop && Option.is_some ctx.bid then
(* We are synthesizing the backward function of a loop body *)
let bid = Option.get ctx.bid in
- let back_vars =
- T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs)
- in
- let tys = List.map (fun (v : var) -> v.ty) back_vars in
+ let loop_id = Option.get ctx.loop_id in
+ let loop = LoopId.Map.find loop_id ctx.loops in
+ let tys = RegionGroupId.Map.find bid loop.back_outputs in
let output = mk_simpl_tuple_ty tys in
mk_output output
else
@@ -1667,7 +1677,11 @@ and translate_panic (ctx : bs_ctx) : texpression =
in
mk_output output
-(** [opt_v]: the value to return, in case we translate a forward body *)
+(** [opt_v]: the value to return, in case we translate a forward body.
+
+ Remark: for now, we can't get there if we are inside a loop.
+ If inside a loop, we use {!translate_return_with_loop}.
+ *)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
@@ -1676,22 +1690,20 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
- or we are translating a backward function, in which case it should be [None]
*)
(* Compute the values that we should return *without the state and the result
- * wrapper* *)
+ wrapper* *)
let output =
match ctx.bid with
| None ->
(* Forward function *)
let v = Option.get opt_v in
typed_value_to_texpression ctx ectx v
- | Some bid ->
+ | Some _ ->
(* Backward function *)
(* Sanity check *)
assert (opt_v = None);
(* Group the variables in which we stored the values we need to give back.
- * See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
- let backward_outputs =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
- in
+ See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
+ let backward_outputs = Option.get ctx.backward_outputs in
let field_values = List.map mk_texpression_from_var backward_outputs in
mk_simpl_tuple_texpression field_values
in
@@ -1728,19 +1740,16 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
(* Forward *)
mk_texpression_from_var
(Option.get loop_info.forward_output_no_state_no_result)
- | Some bid ->
+ | Some _ ->
(* Backward *)
(* Group the variables in which we stored the values we need to give back.
* See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
let backward_outputs =
- let map =
- if ctx.inside_loop then
- (* We are synthesizing a loop body *)
- Option.get ctx.loop_backward_outputs
- else (* Regular function *)
- ctx.backward_outputs
- in
- T.RegionGroupId.Map.find bid map
+ if ctx.inside_loop then
+ (* We are synthesizing a loop body *)
+ Option.get ctx.loop_backward_outputs
+ else (* Regular function *)
+ Option.get ctx.backward_outputs
in
let field_values = List.map mk_texpression_from_var backward_outputs in
mk_simpl_tuple_texpression field_values
@@ -1923,45 +1932,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
^ abs_to_string ctx abs ^ "\n"));
(* When we end an input abstraction, this input abstraction gets back
- * the borrows which it introduced in the context through the input
- * values: by listing those values, we get the values which are given
- * back by one of the backward functions we are synthesizing. *)
- (* Note that we don't support nested borrows for now: if we find
- * an ended synthesized input abstraction, it must be the one corresponding
- * to the backward function wer are synthesizing, it can't be the one
- * for a parent backward function.
- *)
+ the borrows which it introduced in the context through the input
+ values: by listing those values, we get the values which are given
+ back by one of the backward functions we are synthesizing.
+
+ Note that we don't support nested borrows for now: if we find
+ an ended synthesized input abstraction, it must be the one corresponding
+ to the backward function wer are synthesizing, it can't be the one
+ for a parent backward function.
+ *)
let bid = Option.get ctx.bid in
assert (rg_id = bid);
- (* The translation is done as follows:
- - for a given backward function, we choose a set of variables [v_i]
- - when we detect the ended input abstraction which corresponds
- to the backward function, and which consumed the values [consumed_i],
- we introduce:
- {[
- let v_i = consumed_i in
- ...
- ]}
- Then, when we reach the [Return] node, we introduce:
- {[
- (v_i)
- ]}
- *)
- (* First, get the given back variables.
+ (* First, introduce the given back variables.
We don't use the same given back variables if we translate a loop or
the standard body of a function.
*)
- let given_back_variables =
- let map =
- if ctx.inside_loop then
- (* We are synthesizing a loop body *)
- Option.get ctx.loop_backward_outputs
- else (* Regular function body *)
- ctx.backward_outputs
- in
- T.RegionGroupId.Map.find bid map
+ let ctx, given_back_variables =
+ if ctx.inside_loop then
+ (* We are synthesizing a loop body *)
+ let loop_id = Option.get ctx.loop_id in
+ let loop = LoopId.Map.find loop_id ctx.loops in
+ let tys = RegionGroupId.Map.find bid loop.back_outputs in
+ let vars = List.map (fun ty -> (None, ty)) tys in
+ let ctx, vars = fresh_vars vars ctx in
+ ({ ctx with loop_backward_outputs = Some vars }, vars)
+ else
+ (* Regular function body *)
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ let vars = List.combine back_sg.output_names back_sg.outputs in
+ let ctx, vars = fresh_vars vars ctx in
+ ({ ctx with backward_outputs = Some vars }, vars)
in
(* Get the list of values consumed by the abstraction upon ending *)
@@ -2943,22 +2945,15 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Compute the backward outputs *)
let ctx = ref ctx in
- let loop_backward_outputs =
+ let rg_to_given_back_tys =
T.RegionGroupId.Map.map
(fun (_, tys) ->
(* The types shouldn't contain borrows - we can translate them as forward types *)
- let vars =
- List.map
- (fun ty ->
- assert (
- not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty));
- (None, ctx_translate_fwd_ty !ctx ty))
- tys
- in
- (* Introduce fresh variables *)
- let ctx', vars = fresh_vars vars !ctx in
- ctx := ctx';
- vars)
+ List.map
+ (fun ty ->
+ assert (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty));
+ ctx_translate_fwd_ty !ctx ty)
+ tys)
loop.rg_to_given_back_tys
in
let ctx = !ctx in
@@ -2966,12 +2961,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
let back_output_tys =
match ctx.bid with
| None -> None
- | Some rg_id ->
- let back_outputs =
- T.RegionGroupId.Map.find rg_id loop_backward_outputs
- in
- let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in
- Some back_output_tys
+ | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys)
in
(* Add the loop information in the context *)
@@ -3013,6 +3003,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
generics;
forward_inputs = None;
forward_output_no_state_no_result = None;
+ back_outputs = rg_to_given_back_tys;
}
in
let loops = LoopId.Map.add loop_id loop_info ctx.loops in
@@ -3020,13 +3011,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
in
(* Update the context to translate the function end *)
- let ctx_end =
- {
- ctx with
- loop_id = Some loop_id;
- loop_backward_outputs = Some loop_backward_outputs;
- }
- in
+ let ctx_end = { ctx with loop_id = Some loop_id } in
let fun_end = translate_expression loop.end_expr ctx_end in
(* Update the context for the loop body *)
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index e153f4f4..0fa0202b 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -171,8 +171,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
backward_inputs_no_state = RegionGroupId.Map.empty;
(* Initialized just below *)
backward_inputs_with_state = RegionGroupId.Map.empty;
- (* Initialized just below *)
- backward_outputs = RegionGroupId.Map.empty;
+ backward_outputs = None;
loop_backward_outputs = None;
(* Empty for now *)
calls;
@@ -234,20 +233,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
in
let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in
- (* Add the backward outputs *)
- let ctx, backward_outputs =
- List.fold_left_map
- (fun ctx (region_vars : region_var_group) ->
- let gid = region_vars.id in
- let back_sg = RegionGroupId.Map.find gid sg.back_sg in
- let outputs = List.combine back_sg.output_names back_sg.outputs in
- let ctx, vars = SymbolicToPure.fresh_vars outputs ctx in
- (ctx, (gid, vars)))
- ctx regions_hierarchy
- in
- let backward_outputs = RegionGroupId.Map.of_list backward_outputs in
- let ctx = { ctx with backward_outputs } in
-
(* Translate the forward function *)
let pure_forward =
match symbolic_trans with