summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r--compiler/SymbolicToPure.ml148
1 files changed, 132 insertions, 16 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index b024f40e..120689e5 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -163,6 +163,13 @@ type bs_ctx = {
(** The variables that the backward functions will output, corresponding
to the borrows they give back (don't include the backward state)
*)
+ loop_backward_outputs : var list T.RegionGroupId.Map.t option;
+ (** Same as {!backward_outputs}, but for loops (if we entered a loop).
+
+ [None] if we are not inside a loop, [Some] otherwise (and whatever
+ the kind of function we are translating: it will be [Some] even
+ though we are synthesizing a forward function).
+ *)
calls : call_info V.FunCallId.Map.t;
(** The function calls we encountered so far *)
abstractions : (V.abs * texpression list) V.AbstractionId.Map.t;
@@ -255,6 +262,11 @@ let ty_to_string (ctx : bs_ctx) (ty : ty) : string =
let fmt = PrintPure.ast_to_type_formatter fmt in
PrintPure.ty_to_string fmt false ty
+let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string =
+ let fmt = bs_ctx_to_ctx_formatter ctx in
+ let fmt = Print.PC.ctx_to_rtype_formatter fmt in
+ Print.PT.rty_to_string fmt ty
+
let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string =
let type_params = def.type_params in
let type_decls = ctx.type_context.llbc_type_decls in
@@ -829,7 +841,7 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
(* Return *)
(ctx, state_pat)
-let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) :
+let fresh_var_llbc_ty (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) :
bs_ctx * var =
(* Generate the fresh variable *)
let id, var_counter = VarId.fresh ctx.var_counter in
@@ -843,7 +855,7 @@ let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) :
let fresh_named_var_for_symbolic_value (basename : string option)
(sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var =
(* Generate the fresh variable *)
- let ctx, var = fresh_var basename sv.sv_ty ctx in
+ let ctx, var = fresh_var_llbc_ty basename sv.sv_ty ctx in
(* Insert in the map *)
let sv_to_var = V.SymbolicValueId.Map.add_strict sv.sv_id var ctx.sv_to_var in
(* Update the context *)
@@ -981,8 +993,9 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx)
log#ldebug
(lazy
("typed_value_to_texpression: result:" ^ "\n- input value:\n"
- ^ V.show_typed_value v ^ "\n- translated expression:\n"
- ^ show_texpression value));
+ ^ typed_value_to_string ctx v
+ ^ "\n- translated expression:\n"
+ ^ texpression_to_string ctx value));
(* Sanity check *)
type_check_texpression ctx value;
(* Return *)
@@ -1296,7 +1309,21 @@ and translate_panic (ctx : bs_ctx) : texpression =
* but it won't be true anymore once we translate individual blocks *)
(* If we use a state monad, we need to add a lambda for the state variable *)
(* Note that only forward functions return a state *)
- let output_ty = mk_simpl_tuple_ty ctx.sg.doutputs in
+ let output_ty =
+ if ctx.inside_loop && Option.is_some ctx.bid then
+ (* We are synthesizing the backward function of a loop body *)
+ let bid = Option.get ctx.bid in
+ let back_vars =
+ T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs)
+ in
+ let tys = List.map (fun (v : var) -> v.ty) back_vars in
+ mk_simpl_tuple_ty tys
+ else
+ (* Regular function, or forward function (the forward translation for
+ a loop has the same return type as the parent function)
+ *)
+ mk_simpl_tuple_ty ctx.sg.doutputs
+ in
(* TODO: we should use a [Fail] function *)
if ctx.sg.info.effect_info.stateful then
(* Create the [Fail] value *)
@@ -1373,7 +1400,14 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
(* Group the variables in which we stored the values we need to give back.
* See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
let backward_outputs =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
+ let map =
+ if ctx.inside_loop then
+ (* We are synthesizing a loop body *)
+ Option.get ctx.loop_backward_outputs
+ else (* Regular function *)
+ ctx.backward_outputs
+ in
+ T.RegionGroupId.Map.find bid map
in
let field_values = List.map mk_texpression_from_var backward_outputs in
mk_simpl_tuple_texpression field_values
@@ -1535,9 +1569,15 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
=
log#ldebug
(lazy
- ("translate_end_abstraction_synth_input:" ^ "\n- eval_ctx:\n"
- ^ IU.eval_ctx_to_string ectx ^ "\n- abs:\n" ^ IU.abs_to_string ectx abs
- ^ "\n"));
+ ("translate_end_abstraction_synth_input:" ^ "\n- function: "
+ ^ Print.name_to_string ctx.fun_decl.name
+ ^ "\n- rg_id: "
+ ^ T.RegionGroupId.to_string rg_id
+ ^ "\n- loop_id: "
+ ^ Print.option_to_string Pure.LoopId.to_string ctx.loop_id
+ ^ "\n- eval_ctx:\n" ^ IU.eval_ctx_to_string ectx ^ "\n- abs:\n"
+ ^ IU.abs_to_string ectx abs ^ "\n"));
+
(* When we end an input abstraction, this input abstraction gets back
* the borrows which it introduced in the context through the input
* values: by listing those values, we get the values which are given
@@ -1564,12 +1604,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
* (v_i)
* ]}
* *)
- (* First, get the given back variables *)
+ (* First, get the given back variables.
+
+ We don't use the same given back variables if we translate a loop or
+ the standard body of a function.
+ *)
let given_back_variables =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
+ let map =
+ if ctx.inside_loop then
+ (* We are synthesizing a loop body *)
+ Option.get ctx.loop_backward_outputs
+ else (* Regular function body *)
+ ctx.backward_outputs
+ in
+ T.RegionGroupId.Map.find bid map
in
+
(* Get the list of values consumed by the abstraction upon ending *)
let consumed_values = abs_to_consumed ctx ectx abs in
+
+ log#ldebug
+ (lazy
+ ("translate_end_abstraction_synth_input:"
+ ^ "\n\n- given back variables types:\n"
+ ^ Print.list_to_string
+ (fun (v : var) -> ty_to_string ctx v.ty)
+ given_back_variables
+ ^ "\n\n- consumed values:\n"
+ ^ Print.list_to_string
+ (fun e -> texpression_to_string ctx e ^ " : " ^ ty_to_string ctx e.ty)
+ consumed_values
+ ^ "\n"));
+
(* Group the two lists *)
let variables_values = List.combine given_back_variables consumed_values in
(* Sanity check: the two lists match (same types) *)
@@ -1655,11 +1721,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
^ string_of_int (List.length inputs)
^ "): "
- ^ String.concat ", " (List.map show_texpression inputs)
+ ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
^ "\n- inst_sg.inputs ("
^ string_of_int (List.length inst_sg.inputs)
^ "): "
- ^ String.concat ", " (List.map show_ty inst_sg.inputs)));
+ ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
List.iter
(fun (x, ty) -> assert ((x : texpression).ty = ty))
(List.combine inputs inst_sg.inputs);
@@ -2272,6 +2338,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
loop.input_svalues
^ "\n- filtered svl: "
^ (Print.list_to_string (symbolic_value_to_string ctx)) svl
+ ^ "\n- rg_to_abs\n:"
+ ^ T.RegionGroupId.Map.show
+ (fun (rids, tys) ->
+ "(" ^ T.RegionId.Set.show rids ^ ", "
+ ^ Print.list_to_string (rty_to_string ctx) tys
+ ^ ")")
+ loop.rg_to_given_back_tys
^ "\n"));
let ctx, _ = fresh_vars_for_symbolic_values svl ctx in
ctx
@@ -2294,6 +2367,39 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
List.map (fun var -> mk_typed_pattern_from_var var None) inputs
in
+ (* Compute the backward outputs *)
+ let ctx = ref ctx in
+ let loop_backward_outputs =
+ T.RegionGroupId.Map.map
+ (fun (_, tys) ->
+ (* The types shouldn't contain borrows - we can translate them as forward types *)
+ let vars =
+ List.map
+ (fun ty ->
+ assert (
+ not (TypesUtils.ty_has_borrows !ctx.type_context.types_infos ty));
+ (None, ctx_translate_fwd_ty !ctx ty))
+ tys
+ in
+ (* Introduce fresh variables *)
+ let ctx', vars = fresh_vars vars !ctx in
+ ctx := ctx';
+ vars)
+ loop.rg_to_given_back_tys
+ in
+ let ctx = !ctx in
+
+ let back_output_tys =
+ match ctx.bid with
+ | None -> None
+ | Some rg_id ->
+ let back_outputs =
+ T.RegionGroupId.Map.find rg_id loop_backward_outputs
+ in
+ let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in
+ Some back_output_tys
+ in
+
(* Add the loop information in the context *)
let ctx =
assert (not (LoopId.Map.mem loop_id ctx.loops));
@@ -2319,7 +2425,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
in
(* Update the context to translate the function end *)
- let ctx_end = { ctx with loop_id = Some loop_id } in
+ let ctx_end =
+ {
+ ctx with
+ loop_id = Some loop_id;
+ loop_backward_outputs = Some loop_backward_outputs;
+ }
+ in
let fun_end = translate_expression loop.end_expr ctx_end in
(* Update the context for the loop body *)
@@ -2339,10 +2451,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
input_state = (if !Config.use_state then Some ctx.state_var else None);
inputs;
inputs_lvs;
+ back_output_tys;
loop_body;
}
in
- assert (fun_end.ty = loop_body.ty);
+ (* If we translate forward functions: the return type of a loop body is the
+ same as the parent function *)
+ assert (Option.is_some ctx.bid || fun_end.ty = loop_body.ty);
let ty = fun_end.ty in
{ e = loop; ty }
@@ -2524,7 +2639,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
^ "\n- back_state: "
^ String.concat ", " (List.map show_var back_state)
^ "\n- signature.inputs: "
- ^ String.concat ", " (List.map show_ty signature.inputs)));
+ ^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs)
+ ));
assert (
List.for_all
(fun (var, ty) -> (var : var).ty = ty)