diff options
author | Son Ho | 2023-12-19 12:54:40 +0100 |
---|---|---|
committer | Son Ho | 2023-12-19 12:54:40 +0100 |
commit | 4f7bc41dcbc6187512111a81f968726452024d25 (patch) | |
tree | bc78af79887a3165dcf5d7a837992b09cc6d3071 | |
parent | 116b569d1b08a57c3ad66071979a1c966fdad3a2 (diff) |
Simplify SymbolicToPure.bs_ctx.{backward_outputs, loop_backward_outputs}
-rw-r--r-- | compiler/SymbolicToPure.ml | 153 | ||||
-rw-r--r-- | compiler/Translate.ml | 17 |
2 files changed, 70 insertions, 100 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index ea2082c7..93e6cb4e 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -109,6 +109,10 @@ type loop_info = { (** The forward inputs are initialized at [None] *) forward_output_no_state_no_result : var option; (** The forward outputs are initialized at [None] *) + back_outputs : ty list RegionGroupId.Map.t; + (** The map from region group ids to the types of the values given back + by the corresponding loop abstractions. + *) } [@@deriving show] @@ -187,12 +191,11 @@ type bs_ctx = { Same remarks as for {!backward_inputs_no_state}. *) - backward_outputs : var list RegionGroupId.Map.t; + backward_outputs : var list option; (** The variables that the backward functions will output, corresponding to the borrows they give back (don't include the backward state). The translation is done as follows: - - for a given backward function, we choose a set of variables [v_i] - when we detect the ended input abstraction which corresponds to the backward function of the LLBC function we are translating, and which consumed the values [consumed_i] (that we need to give @@ -201,14 +204,20 @@ type bs_ctx = { let v_i = consumed_i in ... ]} - Then, upon reaching the [Return] node, we introduce: + where the [v_i] are fresh, and are stored in the [backward_output]. + - Then, upon reaching the [Return] node, we introduce: {[ - (v_i) + return (v_i) ]} + + The option is [None] before we detect the ended input abstraction, + and [Some] afterwards. *) - loop_backward_outputs : var list RegionGroupId.Map.t option; + loop_backward_outputs : var list option; (** Same as {!backward_outputs}, but for loops (if we entered a loop). + TODO: merge with [backward_outputs]? + [None] if we are not inside a loop, [Some] otherwise (and whatever the kind of function we are translating: it will be [Some] even though we are synthesizing a forward function). @@ -1607,7 +1616,9 @@ let mk_emeta_symbolic_assignments (vars : var list) (values : texpression list) let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = match e with - | S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx + | S.Return (ectx, opt_v) -> + (* Remark: we can't get there if we are inside a loop *) + translate_return ectx opt_v ctx | ReturnWithLoop (loop_id, is_continue) -> translate_return_with_loop loop_id is_continue ctx | Panic -> translate_panic ctx @@ -1644,10 +1655,9 @@ and translate_panic (ctx : bs_ctx) : texpression = if ctx.inside_loop && Option.is_some ctx.bid then (* We are synthesizing the backward function of a loop body *) let bid = Option.get ctx.bid in - let back_vars = - T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) - in - let tys = List.map (fun (v : var) -> v.ty) back_vars in + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in let output = mk_simpl_tuple_ty tys in mk_output output else @@ -1667,7 +1677,11 @@ and translate_panic (ctx : bs_ctx) : texpression = in mk_output output -(** [opt_v]: the value to return, in case we translate a forward body *) +(** [opt_v]: the value to return, in case we translate a forward body. + + Remark: for now, we can't get there if we are inside a loop. + If inside a loop, we use {!translate_return_with_loop}. + *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -1676,22 +1690,20 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) - or we are translating a backward function, in which case it should be [None] *) (* Compute the values that we should return *without the state and the result - * wrapper* *) + wrapper* *) let output = match ctx.bid with | None -> (* Forward function *) let v = Option.get opt_v in typed_value_to_texpression ctx ectx v - | Some bid -> + | Some _ -> (* Backward function *) (* Sanity check *) assert (opt_v = None); (* Group the variables in which we stored the values we need to give back. - * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = - T.RegionGroupId.Map.find bid ctx.backward_outputs - in + See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values in @@ -1728,19 +1740,16 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (* Forward *) mk_texpression_from_var (Option.get loop_info.forward_output_no_state_no_result) - | Some bid -> + | Some _ -> (* Backward *) (* Group the variables in which we stored the values we need to give back. * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) let backward_outputs = - let map = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function *) - ctx.backward_outputs - in - T.RegionGroupId.Map.find bid map + if ctx.inside_loop then + (* We are synthesizing a loop body *) + Option.get ctx.loop_backward_outputs + else (* Regular function *) + Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values @@ -1923,45 +1932,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) ^ abs_to_string ctx abs ^ "\n")); (* When we end an input abstraction, this input abstraction gets back - * the borrows which it introduced in the context through the input - * values: by listing those values, we get the values which are given - * back by one of the backward functions we are synthesizing. *) - (* Note that we don't support nested borrows for now: if we find - * an ended synthesized input abstraction, it must be the one corresponding - * to the backward function wer are synthesizing, it can't be the one - * for a parent backward function. - *) + the borrows which it introduced in the context through the input + values: by listing those values, we get the values which are given + back by one of the backward functions we are synthesizing. + + Note that we don't support nested borrows for now: if we find + an ended synthesized input abstraction, it must be the one corresponding + to the backward function wer are synthesizing, it can't be the one + for a parent backward function. + *) let bid = Option.get ctx.bid in assert (rg_id = bid); - (* The translation is done as follows: - - for a given backward function, we choose a set of variables [v_i] - - when we detect the ended input abstraction which corresponds - to the backward function, and which consumed the values [consumed_i], - we introduce: - {[ - let v_i = consumed_i in - ... - ]} - Then, when we reach the [Return] node, we introduce: - {[ - (v_i) - ]} - *) - (* First, get the given back variables. + (* First, introduce the given back variables. We don't use the same given back variables if we translate a loop or the standard body of a function. *) - let given_back_variables = - let map = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function body *) - ctx.backward_outputs - in - T.RegionGroupId.Map.find bid map + let ctx, given_back_variables = + if ctx.inside_loop then + (* We are synthesizing a loop body *) + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in + let vars = List.map (fun ty -> (None, ty)) tys in + let ctx, vars = fresh_vars vars ctx in + ({ ctx with loop_backward_outputs = Some vars }, vars) + else + (* Regular function body *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let vars = List.combine back_sg.output_names back_sg.outputs in + let ctx, vars = fresh_vars vars ctx in + ({ ctx with backward_outputs = Some vars }, vars) in (* Get the list of values consumed by the abstraction upon ending *) @@ -2943,22 +2945,15 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Compute the backward outputs *) let ctx = ref ctx in - let loop_backward_outputs = + let rg_to_given_back_tys = T.RegionGroupId.Map.map (fun (_, tys) -> (* The types shouldn't contain borrows - we can translate them as forward types *) - let vars = - List.map - (fun ty -> - assert ( - not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)); - (None, ctx_translate_fwd_ty !ctx ty)) - tys - in - (* Introduce fresh variables *) - let ctx', vars = fresh_vars vars !ctx in - ctx := ctx'; - vars) + List.map + (fun ty -> + assert (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)); + ctx_translate_fwd_ty !ctx ty) + tys) loop.rg_to_given_back_tys in let ctx = !ctx in @@ -2966,12 +2961,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = let back_output_tys = match ctx.bid with | None -> None - | Some rg_id -> - let back_outputs = - T.RegionGroupId.Map.find rg_id loop_backward_outputs - in - let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in - Some back_output_tys + | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys) in (* Add the loop information in the context *) @@ -3013,6 +3003,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = generics; forward_inputs = None; forward_output_no_state_no_result = None; + back_outputs = rg_to_given_back_tys; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in @@ -3020,13 +3011,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in (* Update the context to translate the function end *) - let ctx_end = - { - ctx with - loop_id = Some loop_id; - loop_backward_outputs = Some loop_backward_outputs; - } - in + let ctx_end = { ctx with loop_id = Some loop_id } in let fun_end = translate_expression loop.end_expr ctx_end in (* Update the context for the loop body *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index e153f4f4..0fa0202b 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -171,8 +171,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) backward_inputs_no_state = RegionGroupId.Map.empty; (* Initialized just below *) backward_inputs_with_state = RegionGroupId.Map.empty; - (* Initialized just below *) - backward_outputs = RegionGroupId.Map.empty; + backward_outputs = None; loop_backward_outputs = None; (* Empty for now *) calls; @@ -234,20 +233,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx) in let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in - (* Add the backward outputs *) - let ctx, backward_outputs = - List.fold_left_map - (fun ctx (region_vars : region_var_group) -> - let gid = region_vars.id in - let back_sg = RegionGroupId.Map.find gid sg.back_sg in - let outputs = List.combine back_sg.output_names back_sg.outputs in - let ctx, vars = SymbolicToPure.fresh_vars outputs ctx in - (ctx, (gid, vars))) - ctx regions_hierarchy - in - let backward_outputs = RegionGroupId.Map.of_list backward_outputs in - let ctx = { ctx with backward_outputs } in - (* Translate the forward function *) let pure_forward = match symbolic_trans with |