summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/ExtractBase.ml10
-rw-r--r--compiler/InterpreterLoops.ml93
-rw-r--r--compiler/PrintPure.ml17
-rw-r--r--compiler/Pure.ml3
-rw-r--r--compiler/PureMicroPasses.ml29
-rw-r--r--compiler/PureTypeCheck.ml7
-rw-r--r--compiler/SymbolicAst.ml5
-rw-r--r--compiler/SymbolicToPure.ml148
-rw-r--r--compiler/SynthesizeSymbolic.ml18
-rw-r--r--compiler/Translate.ml1
10 files changed, 291 insertions, 40 deletions
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index b952d555..a9b44017 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -177,14 +177,16 @@ type formatter = {
indices to derive unique names for the loops for instance - if there is
exactly one loop, we don't need to use indices)
- loop id (if pertinent)
- - number of region groups (same comment as for the number of loops)
+ - number of region groups
- region group information in case of a backward function
([None] if forward function)
- pair:
- do we generate the forward function (it may have been filtered)?
- - the number of extracted backward functions (not necessarily equal
- to the number of region groups, because we may have filtered
- some of them)
+ - the number of *extracted backward functions* (same comment as for
+ the number of loops)
+ The number of extracted backward functions if not necessarily
+ equal to the number of region groups, because we may have
+ filtered some of them.
TODO: use the fun id for the assumed functions.
*)
decreases_clause_name : A.FunDeclId.id -> fun_name -> string;
diff --git a/compiler/InterpreterLoops.ml b/compiler/InterpreterLoops.ml
index 2d900b7d..11bc7a07 100644
--- a/compiler/InterpreterLoops.ml
+++ b/compiler/InterpreterLoops.ml
@@ -2785,8 +2785,11 @@ let prepare_ashared_loans_no_synth (loop_id : V.LoopId.id) (ctx : C.eval_ctx) :
get_cf_ctx_no_synth (prepare_ashared_loans (Some loop_id)) ctx
(** Compute a fixed-point for the context at the entry of the loop.
- We also return the sets of fixed ids, and the list of symbolic values
- that appear in the fixed point context.
+ We also return:
+ - the sets of fixed ids
+ - the map from region group id to the corresponding abstraction appearing
+ in the fixed point (this is useful to compute the return type of the loop
+ backward functions for instance).
Rem.: the list of symbolic values should be computable by simply exploring
the fixed point environment and listing all the symbolic values we find.
@@ -2794,7 +2797,8 @@ let prepare_ashared_loans_no_synth (loop_id : V.LoopId.id) (ctx : C.eval_ctx) :
the values which are read or modified (some symbolic values may be ignored).
*)
let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
- (eval_loop_body : st_cm_fun) (ctx0 : C.eval_ctx) : C.eval_ctx * ids_sets =
+ (eval_loop_body : st_cm_fun) (ctx0 : C.eval_ctx) :
+ C.eval_ctx * ids_sets * V.abs T.RegionGroupId.Map.t =
(* The continuation for when we exit the loop - we register the
environments upon loop *reentry*, and synthesize nothing by
returning [None]
@@ -2963,7 +2967,7 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
"horizontally": the symbolic values contained in the abstractions (typically
the shared values) will be preserved.
*)
- let fp =
+ let fp, rg_to_abs =
(* List the loop abstractions in the fixed-point *)
let fp_aids, add_aid, _mem_aid = V.AbstractionId.Set.mk_stateful_set () in
@@ -3066,8 +3070,10 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
se, but if it doesn't happen it is bizarre and worth investigating... *)
assert (V.AbstractionId.Set.equal !aids_union !fp_aids);
- (* Merge the abstractions which need to be merged *)
+ (* Merge the abstractions which need to be merged, and compute the map from
+ region id to abstraction id *)
let fp = ref fp in
+ let rg_to_abs = ref T.RegionGroupId.Map.empty in
let _ =
T.RegionGroupId.Map.iter
(fun rg_id ids ->
@@ -3108,9 +3114,13 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
id0 := id0';
()
with ValueMatchFailure _ -> raise (Failure "Unexpected"))
- ids)
+ ids;
+ (* Register the mapping *)
+ let abs = C.ctx_lookup_abs !fp !id0 in
+ rg_to_abs := T.RegionGroupId.Map.add_strict rg_id abs !rg_to_abs)
!fp_ended_aids
in
+ let rg_to_abs = !rg_to_abs in
(* Reorder the loans and borrows in the fresh abstractions in the fixed-point *)
let fp =
@@ -3164,12 +3174,12 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
in
(* Return *)
- fp
+ (fp, rg_to_abs)
in
let fixed_ids = compute_fixed_ids [ fp ] in
(* Return *)
- (fp, fixed_ids)
+ (fp, fixed_ids, rg_to_abs)
(** Split an environment between the fixed abstractions, values, etc. and
the new abstractions, values, etc.
@@ -4127,7 +4137,7 @@ let eval_loop_symbolic (config : C.config) (eval_loop_body : st_cm_fun) :
let loop_id = C.fresh_loop_id () in
(* Compute the fixed point at the loop entrance *)
- let fp_ctx, fixed_ids =
+ let fp_ctx, fixed_ids, rg_to_abs =
compute_loop_entry_fixed_point config loop_id eval_loop_body ctx
in
@@ -4197,8 +4207,71 @@ let eval_loop_symbolic (config : C.config) (eval_loop_body : st_cm_fun) :
^ Print.list_to_string (symbolic_value_to_string ctx) input_svalues
^ "\n\n"));
+ (* For every abstraction introduced by the fixed-point, compute the
+ types of the given back values.
+
+ We need to explore the abstractions, looking for the mutable borrows.
+ Moreover, we list the borrows in the same order as the loans (this
+ is important in {!SymbolicToPure}, where we expect the given back
+ values to have a specific order.
+ *)
+ let compute_abs_given_back_tys (abs : V.abs) : T.RegionId.Set.t * T.rty list =
+ let is_borrow (av : V.typed_avalue) : bool =
+ match av.V.value with
+ | ABorrow _ -> true
+ | ALoan _ -> false
+ | _ -> raise (Failure "Unreachable")
+ in
+ let borrows, loans = List.partition is_borrow abs.avalues in
+
+ let borrows =
+ List.filter_map
+ (fun av ->
+ match av.V.value with
+ | V.ABorrow (V.AMutBorrow (bid, child_av)) ->
+ assert (is_aignored child_av.V.value);
+ Some (bid, child_av.V.ty)
+ | V.ABorrow (V.ASharedBorrow _) -> None
+ | _ -> raise (Failure "Unreachable"))
+ borrows
+ in
+ let borrows = ref (V.BorrowId.Map.of_list borrows) in
+
+ let loan_ids =
+ List.filter_map
+ (fun av ->
+ match av.V.value with
+ | V.ALoan (V.AMutLoan (bid, child_av)) ->
+ assert (is_aignored child_av.V.value);
+ Some bid
+ | V.ALoan (V.ASharedLoan _) -> None
+ | _ -> raise (Failure "Unreachable"))
+ loans
+ in
+
+ (* List the given back types, in the order given by the loans *)
+ let given_back_tys =
+ List.map
+ (fun lid ->
+ let bid =
+ V.BorrowId.InjSubst.find lid fp_bl_corresp.loan_to_borrow_id_map
+ in
+ let ty = V.BorrowId.Map.find bid !borrows in
+ borrows := V.BorrowId.Map.remove bid !borrows;
+ ty)
+ loan_ids
+ in
+ assert (V.BorrowId.Map.is_empty !borrows);
+
+ (abs.regions, given_back_tys)
+ in
+ let rg_to_given_back =
+ T.RegionGroupId.Map.map compute_abs_given_back_tys rg_to_abs
+ in
+
(* Put together *)
- S.synthesize_loop loop_id input_svalues fresh_sids end_expr loop_expr
+ S.synthesize_loop loop_id input_svalues fresh_sids rg_to_given_back end_expr
+ loop_expr
(** Evaluate a loop *)
let eval_loop (config : C.config) (eval_loop_body : st_cm_fun) : st_cm_fun =
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index c13ce238..532271c3 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -632,15 +632,26 @@ and loop_to_string (fmt : ast_formatter) (indent : string)
^ String.concat "; " (List.map (var_to_string type_fmt) loop.inputs)
^ "]"
in
+ let back_output_tys =
+ let tys =
+ match loop.back_output_tys with
+ | None -> ""
+ | Some tys ->
+ String.concat "; "
+ (List.map (ty_to_string (ast_to_type_formatter fmt) false) tys)
+ in
+ "back_output_tys: [" ^ tys ^ "]"
+ in
let fun_end =
texpression_to_string fmt false indent2 indent_incr loop.fun_end
in
let loop_body =
texpression_to_string fmt false indent2 indent_incr loop.loop_body
in
- "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ "fun_end: {\n" ^ indent2
- ^ fun_end ^ "\n" ^ indent1 ^ "}\n" ^ indent1 ^ "loop_body: {\n" ^ indent2
- ^ loop_body ^ "\n" ^ indent1 ^ "}\n" ^ indent ^ "}"
+ "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ back_output_tys ^ "\n"
+ ^ indent1 ^ "fun_end: {\n" ^ indent2 ^ fun_end ^ "\n" ^ indent1 ^ "}\n"
+ ^ indent1 ^ "loop_body: {\n" ^ indent2 ^ loop_body ^ "\n" ^ indent1 ^ "}\n"
+ ^ indent ^ "}"
and meta_to_string (fmt : ast_formatter) (meta : meta) : string =
let meta =
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 1b0a6b5c..118aec50 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -498,6 +498,9 @@ and loop = {
inputs : var list;
inputs_lvs : typed_pattern list;
(** The inputs seen as patterns. See {!fun_body}. *)
+ back_output_tys : ty list option;
+ (** The types of the given back values, if we ar esynthesizing a backward
+ function *)
loop_body : texpression;
}
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index aed5b02d..25d760fe 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -440,13 +440,14 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
input_state;
inputs;
inputs_lvs;
+ back_output_tys;
loop_body;
} =
loop
in
let ctx, fun_end = update_texpression fun_end ctx in
let ctx, loop_body = update_texpression loop_body ctx in
- let inputs = List.map (fun input -> update_var ctx input None) inputs in
+ let inputs = List.map (fun v -> update_var ctx v None) inputs in
let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in
let loop =
{
@@ -457,6 +458,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
input_state;
inputs;
inputs_lvs;
+ back_output_tys;
loop_body;
}
in
@@ -1126,12 +1128,33 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
List.concat [ fuel; fwd_inputs; state; back_inputs ]
in
+ let output, doutputs =
+ match loop.back_output_tys with
+ | None ->
+ (* Forward function: the return type is the same as the
+ parent function *)
+ (fun_sig.output, fun_sig.doutputs)
+ | Some doutputs ->
+ (* Backward function: custom return type *)
+ let output = mk_simpl_tuple_ty doutputs in
+ let output =
+ if loop_effect_info.stateful then
+ mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if loop_effect_info.can_fail then mk_result_ty output
+ else output
+ in
+ (output, doutputs)
+ in
+
let loop_sig =
{
type_params = fun_sig.type_params;
inputs = inputs_tys;
- output = fun_sig.output;
- doutputs = fun_sig.doutputs;
+ output;
+ doutputs;
info = loop_sig_info;
}
in
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 78fd077a..1871f1bc 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -186,7 +186,12 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
List.iter check_branch branches)
| Loop loop ->
assert (loop.fun_end.ty = e.ty);
- assert (loop.loop_body.ty = e.ty);
+ (* If we translate forward functions, the type of the loop is the same
+ as the type of the parent expression - in case of backward functions,
+ the loop doesn't necessarily give back the same values as the parent
+ function
+ *)
+ assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty);
check_texpression ctx loop.fun_end;
check_texpression ctx loop.loop_body
| Meta (_, e_next) ->
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 7f682c9c..0e68d2fd 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -216,6 +216,11 @@ and loop = {
input_svalues : V.symbolic_value list; (** The input symbolic values *)
fresh_svalues : V.symbolic_value_id_set;
(** The symbolic values introduced by the loop fixed-point *)
+ rg_to_given_back_tys :
+ ((T.RegionId.Set.t * T.rty list) T.RegionGroupId.Map.t[@opaque]);
+ (** The map from region group ids to the types of the values given back
+ by the corresponding loop abstractions.
+ *)
end_expr : expression;
(** The end of the function (upon the moment it enters the loop) *)
loop_expr : expression; (** The symbolically executed loop body *)
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index b024f40e..120689e5 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -163,6 +163,13 @@ type bs_ctx = {
(** The variables that the backward functions will output, corresponding
to the borrows they give back (don't include the backward state)
*)
+ loop_backward_outputs : var list T.RegionGroupId.Map.t option;
+ (** Same as {!backward_outputs}, but for loops (if we entered a loop).
+
+ [None] if we are not inside a loop, [Some] otherwise (and whatever
+ the kind of function we are translating: it will be [Some] even
+ though we are synthesizing a forward function).
+ *)
calls : call_info V.FunCallId.Map.t;
(** The function calls we encountered so far *)
abstractions : (V.abs * texpression list) V.AbstractionId.Map.t;
@@ -255,6 +262,11 @@ let ty_to_string (ctx : bs_ctx) (ty : ty) : string =
let fmt = PrintPure.ast_to_type_formatter fmt in
PrintPure.ty_to_string fmt false ty
+let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string =
+ let fmt = bs_ctx_to_ctx_formatter ctx in
+ let fmt = Print.PC.ctx_to_rtype_formatter fmt in
+ Print.PT.rty_to_string fmt ty
+
let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string =
let type_params = def.type_params in
let type_decls = ctx.type_context.llbc_type_decls in
@@ -829,7 +841,7 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
(* Return *)
(ctx, state_pat)
-let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) :
+let fresh_var_llbc_ty (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) :
bs_ctx * var =
(* Generate the fresh variable *)
let id, var_counter = VarId.fresh ctx.var_counter in
@@ -843,7 +855,7 @@ let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) :
let fresh_named_var_for_symbolic_value (basename : string option)
(sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var =
(* Generate the fresh variable *)
- let ctx, var = fresh_var basename sv.sv_ty ctx in
+ let ctx, var = fresh_var_llbc_ty basename sv.sv_ty ctx in
(* Insert in the map *)
let sv_to_var = V.SymbolicValueId.Map.add_strict sv.sv_id var ctx.sv_to_var in
(* Update the context *)
@@ -981,8 +993,9 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx)
log#ldebug
(lazy
("typed_value_to_texpression: result:" ^ "\n- input value:\n"
- ^ V.show_typed_value v ^ "\n- translated expression:\n"
- ^ show_texpression value));
+ ^ typed_value_to_string ctx v
+ ^ "\n- translated expression:\n"
+ ^ texpression_to_string ctx value));
(* Sanity check *)
type_check_texpression ctx value;
(* Return *)
@@ -1296,7 +1309,21 @@ and translate_panic (ctx : bs_ctx) : texpression =
* but it won't be true anymore once we translate individual blocks *)
(* If we use a state monad, we need to add a lambda for the state variable *)
(* Note that only forward functions return a state *)
- let output_ty = mk_simpl_tuple_ty ctx.sg.doutputs in
+ let output_ty =
+ if ctx.inside_loop && Option.is_some ctx.bid then
+ (* We are synthesizing the backward function of a loop body *)
+ let bid = Option.get ctx.bid in
+ let back_vars =
+ T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs)
+ in
+ let tys = List.map (fun (v : var) -> v.ty) back_vars in
+ mk_simpl_tuple_ty tys
+ else
+ (* Regular function, or forward function (the forward translation for
+ a loop has the same return type as the parent function)
+ *)
+ mk_simpl_tuple_ty ctx.sg.doutputs
+ in
(* TODO: we should use a [Fail] function *)
if ctx.sg.info.effect_info.stateful then
(* Create the [Fail] value *)
@@ -1373,7 +1400,14 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
(* Group the variables in which we stored the values we need to give back.
* See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
let backward_outputs =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
+ let map =
+ if ctx.inside_loop then
+ (* We are synthesizing a loop body *)
+ Option.get ctx.loop_backward_outputs
+ else (* Regular function *)
+ ctx.backward_outputs
+ in
+ T.RegionGroupId.Map.find bid map
in
let field_values = List.map mk_texpression_from_var backward_outputs in
mk_simpl_tuple_texpression field_values
@@ -1535,9 +1569,15 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
=
log#ldebug
(lazy
- ("translate_end_abstraction_synth_input:" ^ "\n- eval_ctx:\n"
- ^ IU.eval_ctx_to_string ectx ^ "\n- abs:\n" ^ IU.abs_to_string ectx abs
- ^ "\n"));
+ ("translate_end_abstraction_synth_input:" ^ "\n- function: "
+ ^ Print.name_to_string ctx.fun_decl.name
+ ^ "\n- rg_id: "
+ ^ T.RegionGroupId.to_string rg_id
+ ^ "\n- loop_id: "
+ ^ Print.option_to_string Pure.LoopId.to_string ctx.loop_id
+ ^ "\n- eval_ctx:\n" ^ IU.eval_ctx_to_string ectx ^ "\n- abs:\n"
+ ^ IU.abs_to_string ectx abs ^ "\n"));
+
(* When we end an input abstraction, this input abstraction gets back
* the borrows which it introduced in the context through the input
* values: by listing those values, we get the values which are given
@@ -1564,12 +1604,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
* (v_i)
* ]}
* *)
- (* First, get the given back variables *)
+ (* First, get the given back variables.
+
+ We don't use the same given back variables if we translate a loop or
+ the standard body of a function.
+ *)
let given_back_variables =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
+ let map =
+ if ctx.inside_loop then
+ (* We are synthesizing a loop body *)
+ Option.get ctx.loop_backward_outputs
+ else (* Regular function body *)
+ ctx.backward_outputs
+ in
+ T.RegionGroupId.Map.find bid map
in
+
(* Get the list of values consumed by the abstraction upon ending *)
let consumed_values = abs_to_consumed ctx ectx abs in
+
+ log#ldebug
+ (lazy
+ ("translate_end_abstraction_synth_input:"
+ ^ "\n\n- given back variables types:\n"
+ ^ Print.list_to_string
+ (fun (v : var) -> ty_to_string ctx v.ty)
+ given_back_variables
+ ^ "\n\n- consumed values:\n"
+ ^ Print.list_to_string
+ (fun e -> texpression_to_string ctx e ^ " : " ^ ty_to_string ctx e.ty)
+ consumed_values
+ ^ "\n"));
+
(* Group the two lists *)
let variables_values = List.combine given_back_variables consumed_values in
(* Sanity check: the two lists match (same types) *)
@@ -1655,11 +1721,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
^ string_of_int (List.length inputs)
^ "): "
- ^ String.concat ", " (List.map show_texpression inputs)
+ ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
^ "\n- inst_sg.inputs ("
^ string_of_int (List.length inst_sg.inputs)
^ "): "
- ^ String.concat ", " (List.map show_ty inst_sg.inputs)));
+ ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
List.iter
(fun (x, ty) -> assert ((x : texpression).ty = ty))
(List.combine inputs inst_sg.inputs);
@@ -2272,6 +2338,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
loop.input_svalues
^ "\n- filtered svl: "
^ (Print.list_to_string (symbolic_value_to_string ctx)) svl
+ ^ "\n- rg_to_abs\n:"
+ ^ T.RegionGroupId.Map.show
+ (fun (rids, tys) ->
+ "(" ^ T.RegionId.Set.show rids ^ ", "
+ ^ Print.list_to_string (rty_to_string ctx) tys
+ ^ ")")
+ loop.rg_to_given_back_tys
^ "\n"));
let ctx, _ = fresh_vars_for_symbolic_values svl ctx in
ctx
@@ -2294,6 +2367,39 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
List.map (fun var -> mk_typed_pattern_from_var var None) inputs
in
+ (* Compute the backward outputs *)
+ let ctx = ref ctx in
+ let loop_backward_outputs =
+ T.RegionGroupId.Map.map
+ (fun (_, tys) ->
+ (* The types shouldn't contain borrows - we can translate them as forward types *)
+ let vars =
+ List.map
+ (fun ty ->
+ assert (
+ not (TypesUtils.ty_has_borrows !ctx.type_context.types_infos ty));
+ (None, ctx_translate_fwd_ty !ctx ty))
+ tys
+ in
+ (* Introduce fresh variables *)
+ let ctx', vars = fresh_vars vars !ctx in
+ ctx := ctx';
+ vars)
+ loop.rg_to_given_back_tys
+ in
+ let ctx = !ctx in
+
+ let back_output_tys =
+ match ctx.bid with
+ | None -> None
+ | Some rg_id ->
+ let back_outputs =
+ T.RegionGroupId.Map.find rg_id loop_backward_outputs
+ in
+ let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in
+ Some back_output_tys
+ in
+
(* Add the loop information in the context *)
let ctx =
assert (not (LoopId.Map.mem loop_id ctx.loops));
@@ -2319,7 +2425,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
in
(* Update the context to translate the function end *)
- let ctx_end = { ctx with loop_id = Some loop_id } in
+ let ctx_end =
+ {
+ ctx with
+ loop_id = Some loop_id;
+ loop_backward_outputs = Some loop_backward_outputs;
+ }
+ in
let fun_end = translate_expression loop.end_expr ctx_end in
(* Update the context for the loop body *)
@@ -2339,10 +2451,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
input_state = (if !Config.use_state then Some ctx.state_var else None);
inputs;
inputs_lvs;
+ back_output_tys;
loop_body;
}
in
- assert (fun_end.ty = loop_body.ty);
+ (* If we translate forward functions: the return type of a loop body is the
+ same as the parent function *)
+ assert (Option.is_some ctx.bid || fun_end.ty = loop_body.ty);
let ty = fun_end.ty in
{ e = loop; ty }
@@ -2524,7 +2639,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
^ "\n- back_state: "
^ String.concat ", " (List.map show_var back_state)
^ "\n- signature.inputs: "
- ^ String.concat ", " (List.map show_ty signature.inputs)));
+ ^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs)
+ ));
assert (
List.for_all
(fun (var, ty) -> (var : var).ty = ty)
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index 8c06717a..976b781d 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -163,10 +163,22 @@ let synthesize_forward_end (ctx : Contexts.eval_ctx)
let synthesize_loop (loop_id : V.LoopId.id)
(input_svalues : V.symbolic_value list)
- (fresh_svalues : V.SymbolicValueId.Set.t) (end_expr : expression option)
- (loop_expr : expression option) : expression option =
+ (fresh_svalues : V.SymbolicValueId.Set.t)
+ (rg_to_given_back_tys :
+ (T.RegionId.Set.t * T.rty list) T.RegionGroupId.Map.t)
+ (end_expr : expression option) (loop_expr : expression option) :
+ expression option =
match (end_expr, loop_expr) with
| None, None -> None
| Some end_expr, Some loop_expr ->
- Some (Loop { loop_id; input_svalues; fresh_svalues; end_expr; loop_expr })
+ Some
+ (Loop
+ {
+ loop_id;
+ input_svalues;
+ fresh_svalues;
+ rg_to_given_back_tys;
+ end_expr;
+ loop_expr;
+ })
| _ -> raise (Failure "Unreachable")
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 800bac00..66280ed7 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -144,6 +144,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
backward_inputs = T.RegionGroupId.Map.empty;
(* Empty for now *)
backward_outputs = T.RegionGroupId.Map.empty;
+ loop_backward_outputs = None;
(* Empty for now *)
calls;
abstractions;