summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/SymbolicToPure.ml214
1 files changed, 143 insertions, 71 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 2db5f66c..922f0375 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -128,6 +128,7 @@ type loop_info = {
back_outputs : ty list RegionGroupId.Map.t;
(** The map from region group ids to the types of the values given back
by the corresponding loop abstractions.
+ This map is partial.
*)
back_funs : texpression option RegionGroupId.Map.t option;
(** Same as {!call_info.back_funs}.
@@ -329,6 +330,10 @@ let pure_ty_to_string (ctx : bs_ctx) (ty : ty) : string =
let env = bs_ctx_to_pure_fmt_env ctx in
PrintPure.ty_to_string env false ty
+let pure_var_to_string (ctx : bs_ctx) (v : var) : string =
+ let env = bs_ctx_to_pure_fmt_env ctx in
+ PrintPure.var_to_string env v
+
let ty_to_string (ctx : bs_ctx) (ty : T.ty) : string =
let env = bs_ctx_to_fmt_env ctx in
Print.Types.ty_to_string env ty
@@ -1251,41 +1256,56 @@ let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info)
(** Compute the arrow types for all the backward functions.
If a backward function has no inputs/outputs we filter it.
+
+ We may also filter the region group ids (param [keep_rg_ids]).
+ This is useful for the loops: not all the
+ parent function region groups can be linked to a region abstraction
+ introduced by the loop.
*)
let compute_back_tys_with_info (dsg : Pure.decomposed_fun_sig)
+ ?(keep_rg_ids : RegionGroupId.Set.t option = None)
(subst : (generic_args * trait_instance_id) option) :
(back_sg_info * ty) option list =
+ let keep_rg_id =
+ match keep_rg_ids with
+ | None -> fun _ -> true
+ | Some ids -> fun id -> RegionGroupId.Set.mem id ids
+ in
List.map
- (fun (back_sg : back_sg_info) ->
- let effect_info = back_sg.effect_info in
- (* Compute the input/output types *)
- let inputs = List.map snd back_sg.inputs in
- let outputs = back_sg.outputs in
- (* Filter if necessary *)
- if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then
- None
- else
- let output = mk_simpl_tuple_ty outputs in
- let output =
- mk_back_output_ty_from_effect_info effect_info inputs output
- in
- let ty = mk_arrows inputs output in
- (* Substitute - TODO: normalize *)
- let ty =
- match subst with
- | None -> ty
- | Some (generics, tr_self) ->
- let subst =
- make_subst_from_generics dsg.generics generics tr_self
- in
- ty_substitute subst ty
- in
- Some (back_sg, ty))
- (RegionGroupId.Map.values dsg.back_sg)
+ (fun ((rg_id, back_sg) : RegionGroupId.id * back_sg_info) ->
+ if keep_rg_id rg_id then
+ let effect_info = back_sg.effect_info in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = back_sg.outputs in
+ (* Filter if necessary *)
+ if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then
+ None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ (* Substitute - TODO: normalize *)
+ let ty =
+ match subst with
+ | None -> ty
+ | Some (generics, tr_self) ->
+ let subst =
+ make_subst_from_generics dsg.generics generics tr_self
+ in
+ ty_substitute subst ty
+ in
+ Some (back_sg, ty)
+ else (* We ignore this region group *)
+ None)
+ (RegionGroupId.Map.bindings dsg.back_sg)
let compute_back_tys (dsg : Pure.decomposed_fun_sig)
+ ?(keep_rg_ids : RegionGroupId.Set.t option = None)
(subst : (generic_args * trait_instance_id) option) : ty option list =
- List.map (Option.map snd) (compute_back_tys_with_info dsg subst)
+ List.map (Option.map snd) (compute_back_tys_with_info dsg ~keep_rg_ids subst)
(** Compute the output type of a function, from a decomposed signature
(the output type contains the type of the value returned by the forward
@@ -1405,8 +1425,14 @@ let fresh_opt_vars (vars : (string option * ty) option list) (ctx : bs_ctx) :
(ctx, Some var))
ctx vars
-(* Introduce variables for the backward functions *)
-let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list =
+(* Introduce variables for the backward functions.
+
+ We may filter the region group ids. This is useful for the loops: not all the
+ parent function region groups can be linked to a region abstraction
+ introduced by the loop.
+*)
+let fresh_back_vars_for_current_fun (ctx : bs_ctx)
+ (keep_rg_ids : RegionGroupId.Set.t option) : bs_ctx * var option list =
(* We lookup the LLBC definition in an attempt to derive pretty names
for the backward functions. *)
let back_var_names =
@@ -1436,7 +1462,9 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list =
Some name)
(RegionGroupId.Map.bindings ctx.sg.back_sg)
in
- let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in
+ let back_vars =
+ List.combine back_var_names (compute_back_tys ctx.sg ~keep_rg_ids None)
+ in
let back_vars =
List.map
(fun (name, ty) ->
@@ -1858,9 +1886,11 @@ let mk_emeta_symbolic_assignments (vars : var list) (values : texpression list)
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
| S.Return (ectx, opt_v) ->
- (* Remark: we can't get there if we are inside a loop *)
+ (* We reached a return.
+ Remark: we can't get there if we are inside a loop. *)
translate_return ectx opt_v ctx
| ReturnWithLoop (loop_id, is_continue) ->
+ (* We reached a return and are inside a loop. *)
translate_return_with_loop loop_id is_continue ctx
| Panic -> translate_panic ctx
| FunCall (call, e) -> translate_function_call call e ctx
@@ -1872,6 +1902,10 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
translate_intro_symbolic ectx p sv v e ctx
| Meta (meta, e) -> translate_emeta meta e ctx
| ForwardEnd (ectx, loop_input_values, e, back_e) ->
+ (* Translate the end of a function, or the end of a loop.
+
+ The case where we (re-)enter a loop is handled here.
+ *)
translate_forward_end ectx loop_input_values e back_e ctx
| Loop loop -> translate_loop loop ctx
@@ -1988,7 +2022,13 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
(* Backward *)
(* Group the variables in which we stored the values we need to give back.
* See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
- let backward_outputs = Option.get ctx.backward_outputs in
+ (* It can happen that we did not end any output abstraction, because the
+ loop didn't use borrows corresponding to the region we just ended.
+ If this happens, there are no backward outputs.
+ *)
+ let backward_outputs =
+ match ctx.backward_outputs with Some outputs -> outputs | None -> []
+ in
let field_values = List.map mk_texpression_from_var backward_outputs in
mk_simpl_tuple_texpression field_values
in
@@ -2118,7 +2158,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| Some (back_sg, ty) ->
(* We insert a name for the variable only if the function
can fail: if it can fail, it means the call returns a backward
- function. Otherwise, we it directly returns the value given
+ function. Otherwise, it directly returns the value given
back by the backward function, which means we shouldn't
give it a name like "back..." (it doesn't make sense) *)
let name =
@@ -2944,12 +2984,9 @@ and translate_forward_end (ectx : C.eval_ctx)
(* We are translating the forward function - nothing to do *)
(ctx, fwd_e, fun e -> e)
| Some bid ->
- (* There are two cases here:
- - if we split the fwd/backward functions, we simply need to update
- the state.
- - if we don't split, we also need to wrap the expression in a
- lambda, which introduces the additional inputs of the backward
- function
+ (* We need to update the state, and wrap the expression in a
+ lambda, which introduces the additional inputs of the backward
+ function.
*)
let ctx =
(* Introduce variables for the inputs and the state variable
@@ -2994,15 +3031,9 @@ and translate_forward_end (ectx : C.eval_ctx)
finish e
in
- (* There are two cases, depending on whether we are splitting the forward/backward
- functions or not.
-
- - if we split, then we simply need to translate the proper "end" expression,
- that is the end of the forward function, or of the backward function we
- are currently translating.
- - if we don't split, then we need to translate the end of the forward
- function (this is the value we will return) and generate the bodies
- of the backward functions (which we will also return).
+ (* We need to translate the end of the forward
+ function (this is the value we will return) and generate the bodies
+ of the backward functions (which we will also return).
Update the current state with the additional state received by the backward
function, if needs be, and lookup the proper expression.
@@ -3013,11 +3044,31 @@ and translate_forward_end (ectx : C.eval_ctx)
let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
let fwd_e = translate_one_end ctx None in
- (* Introduce the backward functions. *)
+ (* If we reached a loop: if we are *inside* a loop, we need to ignore the
+ backward functions which are not associated to region abstractions.
+ *)
+ let keep_rg_ids =
+ match ctx.loop_id with
+ | None -> None
+ | Some loop_id ->
+ if ctx.inside_loop then
+ let loop_info = LoopId.Map.find loop_id ctx.loops in
+ Some
+ (RegionGroupId.Set.of_list
+ (RegionGroupId.Map.keys loop_info.back_outputs))
+ else None
+ in
+ let keep_rg_id =
+ match keep_rg_ids with
+ | None -> fun _ -> true
+ | Some ids -> fun id -> RegionGroupId.Set.mem id ids
+ in
+
let back_el =
List.map
(fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
- translate_one_end ctx (Some gid))
+ if keep_rg_id gid then Some (translate_one_end ctx (Some gid))
+ else None)
(RegionGroupId.Map.bindings ctx.sg.back_sg)
in
@@ -3027,17 +3078,20 @@ and translate_forward_end (ectx : C.eval_ctx)
inputs. *)
let evaluate_backs =
List.map
- (fun (sg : back_sg_info) ->
- if !Config.simplify_merged_fwd_backs then
- sg.inputs = [] && sg.effect_info.can_fail
- else false)
- (RegionGroupId.Map.values ctx.sg.back_sg)
+ (fun ((rg_id, sg) : RegionGroupId.id * back_sg_info) ->
+ if keep_rg_id rg_id then
+ Some
+ (if !Config.simplify_merged_fwd_backs then
+ sg.inputs = [] && sg.effect_info.can_fail
+ else false)
+ else None)
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
in
(* Introduce variables for the backward functions.
We lookup the LLBC definition in an attempt to derive pretty names
for those functions. *)
- let _, back_vars = fresh_back_vars_for_current_fun ctx in
+ let _, back_vars = fresh_back_vars_for_current_fun ctx keep_rg_ids in
(* Create the return expressions *)
let vars =
@@ -3072,7 +3126,9 @@ and translate_forward_end (ectx : C.eval_ctx)
let back_vars_els =
List.filter_map
(fun (v, (eval, el)) ->
- match v with None -> None | Some v -> Some (v, eval, el))
+ match v with
+ | None -> None
+ | Some v -> Some (v, Option.get eval, Option.get el))
(List.combine back_vars (List.combine evaluate_backs back_el))
in
let e =
@@ -3154,7 +3210,16 @@ and translate_forward_end (ectx : C.eval_ctx)
backward functions of the outer function.
*)
let ctx, back_funs_map, back_funs =
- let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
+ (* We need to filter the region groups which are not linked to region
+ abstractions appearing in the loop, so as not to introduce unnecessary
+ backward functions. *)
+ let keep_rg_ids =
+ RegionGroupId.Set.of_list
+ (RegionGroupId.Map.keys loop_info.back_outputs)
+ in
+ let ctx, back_vars =
+ fresh_back_vars_for_current_fun ctx (Some keep_rg_ids)
+ in
let back_funs =
List.filter_map
(fun v ->
@@ -3266,10 +3331,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
^ (Print.list_to_string (symbolic_value_to_string ctx)) svl
^ "\n- rg_to_abs\n:"
^ T.RegionGroupId.Map.show
- (fun (rids, tys) ->
- "(" ^ T.RegionId.Set.show rids ^ ", "
- ^ Print.list_to_string (ty_to_string ctx) tys
- ^ ")")
+ (Print.list_to_string (ty_to_string ctx))
loop.rg_to_given_back_tys
^ "\n"));
let ctx, _ = fresh_vars_for_symbolic_values svl ctx in
@@ -3297,7 +3359,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
let ctx = ref ctx in
let rg_to_given_back_tys =
RegionGroupId.Map.map
- (fun (_, tys) ->
+ (fun tys ->
(* The types shouldn't contain borrows - we can translate them as forward types *)
List.map
(fun ty ->
@@ -3313,11 +3375,24 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
let back_effect_infos, output_ty =
(* The loop backward functions consume the same additional inputs as the parent
function, but have custom outputs *)
- let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
- let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
+ log#ldebug
+ (lazy
+ (let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
+ "translate_loop:" ^ "\n- back_sgs: "
+ ^ (Print.list_to_string
+ (Print.pair_to_string RegionGroupId.to_string show_back_sg_info))
+ back_sgs
+ ^ "\n- given_back_tys: "
+ ^ (RegionGroupId.Map.to_string None
+ (Print.list_to_string (pure_ty_to_string ctx)))
+ rg_to_given_back_tys
+ ^ "\n"));
let back_info_tys =
List.map
- (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
+ (fun ((rg_id, given_back) : RegionGroupId.id * ty list) ->
+ (* Lookup the effect information about the parent function region group
+ associated to this loop region abstraction *)
+ let back_sg = RegionGroupId.Map.find rg_id ctx.sg.back_sg in
(* Remark: the effect info of the backward function for the loop
is almost the same as for the backward function of the parent function.
Quite importantly, the fact that the function is stateful and/or can fail
@@ -3342,8 +3417,8 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
let ty = mk_arrows inputs output in
Some ty
in
- ((id, effect_info), ty))
- (List.combine back_sgs given_back_tys)
+ ((rg_id, effect_info), ty))
+ (RegionGroupId.Map.bindings rg_to_given_back_tys)
in
let back_info = List.map fst back_info_tys in
let back_info = RegionGroupId.Map.of_list back_info in
@@ -3445,9 +3520,6 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
loop_body;
}
in
- (* If we translate forward functions: the return type of a loop body is the
- same as the parent function *)
- assert (Option.is_some ctx.bid || fun_end.ty = loop_body.ty);
let ty = fun_end.ty in
{ e = loop; ty }