summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/Cps.ml4
-rw-r--r--compiler/Interpreter.ml44
-rw-r--r--compiler/InterpreterBorrowsCore.ml6
-rw-r--r--compiler/InterpreterExpressions.mli2
-rw-r--r--compiler/InterpreterLoops.ml116
-rw-r--r--compiler/InterpreterUtils.ml46
-rw-r--r--compiler/Pure.ml2
-rw-r--r--compiler/PureUtils.ml2
-rw-r--r--compiler/SCC.ml2
-rw-r--r--compiler/SymbolicAst.ml104
-rw-r--r--compiler/SymbolicToPure.ml808
-rw-r--r--compiler/SynthesizeSymbolic.ml10
-rw-r--r--compiler/Translate.ml32
-rw-r--r--compiler/Values.ml11
14 files changed, 877 insertions, 312 deletions
diff --git a/compiler/Cps.ml b/compiler/Cps.ml
index 1b5164a1..8763ff78 100644
--- a/compiler/Cps.ml
+++ b/compiler/Cps.ml
@@ -17,7 +17,7 @@ type statement_eval_res =
| Return
| Panic
| LoopReturn (** We reached a return statement *while inside a loop* *)
- | EndEnterLoop of V.typed_value V.SymbolicValueId.Map.t
+ | EndEnterLoop of V.loop_id * V.typed_value V.SymbolicValueId.Map.t
(** When we enter a loop, we delegate the end of the function is
synthesized with a call to the loop translation. We use this
evaluation result to transmit the fact that we end evaluation
@@ -26,7 +26,7 @@ type statement_eval_res =
We provide the list of values for the translated loop function call
(or to be more precise the input values instantiation).
*)
- | EndContinue of V.typed_value V.SymbolicValueId.Map.t
+ | EndContinue of V.loop_id * V.typed_value V.SymbolicValueId.Map.t
(** For loop translations: we end with a continue (i.e., a recursive call
to the translation for the loop body).
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml
index 7a85461e..d5032756 100644
--- a/compiler/Interpreter.ml
+++ b/compiler/Interpreter.ml
@@ -132,8 +132,9 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context)
*)
let evaluate_function_symbolic_synthesize_backward_from_return
(config : C.config) (fdef : A.fun_decl) (inst_sg : A.inst_fun_sig)
- (back_id : T.RegionGroupId.id) (entered_loop : bool) (inside_loop : bool)
- (ctx : C.eval_ctx) : SA.expression option =
+ (back_id : T.RegionGroupId.id) (loop_id : V.LoopId.id option)
+ (inside_loop : bool) (ctx : C.eval_ctx) : SA.expression option =
+ let entered_loop = Option.is_some loop_id in
(* We need to instantiate the function signature - to retrieve
* the return type. Note that it is important to re-generate
* an instantiation of the signature, so that we use fresh
@@ -238,9 +239,7 @@ let evaluate_function_symbolic_synthesize_backward_from_return
| V.Loop (_, rg_id', kind) ->
let rg_id' = Option.get rg_id' in
let is_ret =
- match kind with
- | V.LoopSynthInput -> true
- | V.LoopSynthRet -> false
+ match kind with V.LoopSynthInput -> true | V.LoopCall -> false
in
rg_id' = back_id && is_ret
| _ -> false
@@ -255,7 +254,12 @@ let evaluate_function_symbolic_synthesize_backward_from_return
cf target_abs_ids
in
(* Generate the Return node *)
- let cf_return : m_fun = fun ctx -> Some (SA.Return (ctx, None)) in
+ let cf_return : m_fun =
+ fun ctx ->
+ match loop_id with
+ | None -> Some (SA.Return (ctx, None))
+ | Some loop_id -> Some (SA.ReturnWithLoop (loop_id, inside_loop))
+ in
(* Apply *)
cf_end_target_abs cf_return ctx
in
@@ -285,6 +289,7 @@ let evaluate_function_symbolic (synthesize : bool)
(* Create the continuation to finish the evaluation *)
let config = C.mk_config C.SymbolicMode in
let cf_finish res ctx =
+ let ctx0 = ctx in
log#ldebug
(lazy
("evaluate_function_symbolic: cf_finish: "
@@ -321,7 +326,6 @@ let evaluate_function_symbolic (synthesize : bool)
abstractions to consume the return value, then end all the
abstractions up to the one in which we are interested.
*)
- let entered_loop = false in
let inside_loop =
match res with
| Return -> false
@@ -331,7 +335,7 @@ let evaluate_function_symbolic (synthesize : bool)
let finish_back_eval back_id =
Option.get
(evaluate_function_symbolic_synthesize_backward_from_return config
- fdef inst_sg back_id entered_loop inside_loop ctx)
+ fdef inst_sg back_id None inside_loop ctx)
in
let back_el =
T.RegionGroupId.mapi
@@ -340,11 +344,18 @@ let evaluate_function_symbolic (synthesize : bool)
in
let back_el = T.RegionGroupId.Map.of_list back_el in
(* Put everything together *)
- S.synthesize_forward_end None fwd_e back_el
+ S.synthesize_forward_end ctx0 None fwd_e back_el
else None
- | EndEnterLoop loop_input_values | EndContinue loop_input_values ->
+ | EndEnterLoop (loop_id, loop_input_values)
+ | EndContinue (loop_id, loop_input_values) ->
(* Similar to [Return]: we have to play different endings *)
if synthesize then
+ let inside_loop =
+ match res with
+ | EndEnterLoop _ -> false
+ | EndContinue _ -> true
+ | _ -> raise (Failure "Unreachable")
+ in
(* Forward translation *)
let fwd_e =
(* Pop the frame - there is no returned value to pop: in the
@@ -353,7 +364,7 @@ let evaluate_function_symbolic (synthesize : bool)
let cf_pop = pop_frame config pop_return_value in
(* Generate the Return node *)
let cf_return _ret_value : m_fun =
- fun ctx -> Some (SA.Return (ctx, None))
+ fun _ctx -> Some (SA.ReturnWithLoop (loop_id, inside_loop))
in
(* Apply *)
cf_pop cf_return ctx
@@ -363,17 +374,10 @@ let evaluate_function_symbolic (synthesize : bool)
abstractions to consume the return value, then end all the
abstractions up to the one in which we are interested.
*)
- let entered_loop = true in
- let inside_loop =
- match res with
- | EndEnterLoop _ -> false
- | EndContinue _ -> true
- | _ -> raise (Failure "Unreachable")
- in
let finish_back_eval back_id =
Option.get
(evaluate_function_symbolic_synthesize_backward_from_return config
- fdef inst_sg back_id entered_loop inside_loop ctx)
+ fdef inst_sg back_id (Some loop_id) inside_loop ctx)
in
let back_el =
T.RegionGroupId.mapi
@@ -382,7 +386,7 @@ let evaluate_function_symbolic (synthesize : bool)
in
let back_el = T.RegionGroupId.Map.of_list back_el in
(* Put everything together *)
- S.synthesize_forward_end (Some loop_input_values) fwd_e back_el
+ S.synthesize_forward_end ctx0 (Some loop_input_values) fwd_e back_el
else None
| Panic ->
(* Note that as we explore all the execution branches, one of
diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml
index e3ad26f9..55365043 100644
--- a/compiler/InterpreterBorrowsCore.ml
+++ b/compiler/InterpreterBorrowsCore.ml
@@ -186,9 +186,9 @@ let projection_contains (ty1 : T.rty) (rset1 : T.RegionId.Set.t) (ty2 : T.rty)
The loan is referred to by a borrow id.
- Rem.: if the {!g_loan_content} is {!Concrete}, the {!abs_or_var_id} is not
- necessarily {!VarId} or {!DummyVarId}: there can be concrete loans in
- abstractions (in the shared values).
+ Rem.: if the {!InterpreterUtils.g_loan_content} is {!constructor:Aeneas.InterpreterUtils.concrete_or_abs.Concrete},
+ the {!InterpreterUtils.abs_or_var_id} is not necessarily {!constructor:Aeneas.InterpreterUtils.abs_or_var_id.VarId} or
+ {!constructor:Aeneas.InterpreterUtils.abs_or_var_id.DummyVarId}: there can be concrete loans in abstractions (in the shared values).
*)
let lookup_loan_opt (ek : exploration_kind) (l : V.BorrowId.id)
(ctx : C.eval_ctx) : (abs_or_var_id * g_loan_content) option =
diff --git a/compiler/InterpreterExpressions.mli b/compiler/InterpreterExpressions.mli
index fa717041..3beba610 100644
--- a/compiler/InterpreterExpressions.mli
+++ b/compiler/InterpreterExpressions.mli
@@ -64,7 +64,7 @@ val eval_operands :
Transmits the computed rvalue to the received continuation.
- Note that this function fails on {!constructor:E.Discriminant}: discriminant
+ Note that this function fails on {!constructor:Aeneas.Expressions.rvalue.Discriminant}: discriminant
reads should have been eliminated from the AST.
*)
val eval_rvalue_not_global :
diff --git a/compiler/InterpreterLoops.ml b/compiler/InterpreterLoops.ml
index 29e68ca0..82a9011c 100644
--- a/compiler/InterpreterLoops.ml
+++ b/compiler/InterpreterLoops.ml
@@ -1766,7 +1766,7 @@ let match_ctxs (check_equiv : bool) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
(ctx1 : C.eval_ctx) : ids_maps option =
log#ldebug
(lazy
- ("ctxs_are_equivalent:\n\n- fixed_ids:\n" ^ show_ids_sets fixed_ids
+ ("match_ctxs:\n\n- fixed_ids:\n" ^ show_ids_sets fixed_ids
^ "\n\n- ctx0:\n"
^ eval_ctx_to_string_no_filter ctx0
^ "\n\n- ctx1:\n"
@@ -1898,8 +1898,8 @@ let match_ctxs (check_equiv : bool) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
let rec match_envs (env0 : C.env) (env1 : C.env) : unit =
log#ldebug
(lazy
- ("ctxs_are_equivalent: match_envs:\n\n- fixed_ids:\n"
- ^ show_ids_sets fixed_ids ^ "\n\n- rid_map: "
+ ("match_ctxs: match_envs:\n\n- fixed_ids:\n" ^ show_ids_sets fixed_ids
+ ^ "\n\n- rid_map: "
^ T.RegionId.InjSubst.show_t !rid_map
^ "\n- blid_map: "
^ V.BorrowId.InjSubst.show_t !blid_map
@@ -1924,7 +1924,7 @@ let match_ctxs (check_equiv : bool) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
assert (b0 = b1);
assert (v0 = v1);
(* The ids present in the left value must be fixed *)
- let ids = compute_typed_value_ids v0 in
+ let ids, _ = compute_typed_value_ids v0 in
assert ((not S.check_equiv) || ids_are_fixed ids);
(* Continue *)
match_envs env0' env1')
@@ -1939,22 +1939,20 @@ let match_ctxs (check_equiv : bool) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
(* Continue *)
match_envs env0' env1'
| C.Abs abs0 :: env0', C.Abs abs1 :: env1' ->
- log#ldebug (lazy "ctxs_are_equivalent: match_envs: matching abs");
+ log#ldebug (lazy "match_ctxs: match_envs: matching abs");
(* Same as for the dummy values: there are two cases *)
if V.AbstractionId.Set.mem abs0.abs_id fixed_ids.aids then (
- log#ldebug
- (lazy "ctxs_are_equivalent: match_envs: matching abs: fixed abs");
+ log#ldebug (lazy "match_ctxs: match_envs: matching abs: fixed abs");
(* Still in the prefix: the abstractions must be the same *)
assert (abs0 = abs1);
(* Their ids must be fixed *)
- let ids = compute_abs_ids abs0 in
+ let ids, _ = compute_abs_ids abs0 in
assert ((not S.check_equiv) || ids_are_fixed ids);
(* Continue *)
match_envs env0' env1')
else (
log#ldebug
- (lazy
- "ctxs_are_equivalent: match_envs: matching abs: not fixed abs");
+ (lazy "match_ctxs: match_envs: matching abs: not fixed abs");
(* Match the values *)
match_abstractions abs0 abs1;
(* Continue *)
@@ -2115,10 +2113,17 @@ let loop_join_origin_with_continue_ctxs (config : C.config)
((old_ctx, ctxl), !joined_ctx)
(** Compute a fixed-point for the context at the entry of the loop.
- We also return the sets of fixed ids.
+ We also return the sets of fixed ids, and the list of symbolic values
+ that appear in the fixed point context.
+
+ Rem.: the list of symbolic values should be computable by simply exploring
+ the fixed point environment and listing all the symbolic values we find.
+ In the future, we might want to do something more precise, by listing only
+ the values which are read or modified (some symbolic values may be ignored).
*)
let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
- (eval_loop_body : st_cm_fun) (ctx0 : C.eval_ctx) : C.eval_ctx * ids_sets =
+ (eval_loop_body : st_cm_fun) (ctx0 : C.eval_ctx) :
+ C.eval_ctx * ids_sets * V.symbolic_value list =
(* The continuation for when we exit the loop - we register the
environments upon loop *reentry*, and synthesize nothing by
returning [None]
@@ -2147,7 +2152,7 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
the borrows/loans which end during the first loop iteration (we do
one loop iteration, then set it to [Some].
*)
- let fixed_ids = ref None in
+ let fixed_ids : ids_sets option ref = ref None in
(* Join the contexts at the loop entry *)
let join_ctxs (ctx0 : C.eval_ctx) : C.eval_ctx =
@@ -2159,8 +2164,8 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
match !fixed_ids with
| Some _ -> ctx0
| None ->
- let old_ids = compute_context_ids ctx0 in
- let new_ids = compute_contexts_ids !ctxs in
+ let old_ids, _ = compute_context_ids ctx0 in
+ let new_ids, _ = compute_contexts_ids !ctxs in
let blids = V.BorrowId.Set.diff old_ids.blids new_ids.blids in
let aids = V.AbstractionId.Set.diff old_ids.aids new_ids.aids in
(* End those borrows and abstractions *)
@@ -2181,7 +2186,7 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
will detect them and ask to end them) we do it preemptively.
*)
ctxs := List.map (end_borrows_abs blids aids) !ctxs;
- fixed_ids := Some (compute_context_ids ctx0);
+ fixed_ids := Some (fst (compute_context_ids ctx0));
ctx0
in
@@ -2200,10 +2205,10 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
(* Compute the set of fixed ids - for the symbolic ids, we compute the
intersection of ids between the original environment and [ctx1]
and [ctx2] *)
- let fixed_ids = compute_context_ids (Option.get !ctx_upon_entry) in
+ let fixed_ids, _ = compute_context_ids (Option.get !ctx_upon_entry) in
let { aids; blids; borrow_ids; loan_ids; dids; rids; sids } = fixed_ids in
- let fixed_ids1 = compute_context_ids ctx1 in
- let fixed_ids2 = compute_context_ids ctx2 in
+ let fixed_ids1, _ = compute_context_ids ctx1 in
+ let fixed_ids2, _ = compute_context_ids ctx2 in
let sids =
V.SymbolicValueId.Set.inter sids
(V.SymbolicValueId.Set.inter fixed_ids1.sids fixed_ids2.sids)
@@ -2211,13 +2216,30 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
let fixed_ids = { aids; blids; borrow_ids; loan_ids; dids; rids; sids } in
fixed_ids
in
- let equiv_ctxs (ctx1 : C.eval_ctx) (ctx2 : C.eval_ctx) : bool =
+ let equiv_ctxs (ctx1 : C.eval_ctx) (ctx2 : C.eval_ctx) :
+ V.symbolic_value list option =
let fixed_ids = compute_fixed_ids ctx1 ctx2 in
- ctxs_are_equivalent fixed_ids ctx1 ctx2
+ let check_equivalent = true in
+ match match_ctxs check_equivalent fixed_ids ctx1 ctx2 with
+ | None -> None
+ | Some maps ->
+ (* Compute the list of symbolic value ids *)
+ let sidl =
+ List.map fst (V.SymbolicValueId.Map.bindings maps.sid_to_value_map)
+ in
+ (* Lookup the symbolic value themselves *)
+ let _, ids_to_values = compute_context_ids ctx1 in
+ let svl =
+ List.map
+ (fun sid ->
+ V.SymbolicValueId.Map.find sid ids_to_values.sids_to_values)
+ sidl
+ in
+ Some svl
in
let max_num_iter = Config.loop_fixed_point_max_num_iters in
let rec compute_fixed_point (ctx : C.eval_ctx) (i0 : int) (i : int) :
- C.eval_ctx =
+ C.eval_ctx * V.symbolic_value list =
if i = 0 then
raise
(Failure
@@ -2240,12 +2262,20 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
^ "\n\n"));
(* Check if we reached a fixed point: if not, iterate *)
- if equiv_ctxs ctx ctx1 then ctx1 else compute_fixed_point ctx1 i0 (i - 1)
+ match equiv_ctxs ctx ctx1 with
+ | Some svl -> (ctx1, svl)
+ | None -> compute_fixed_point ctx1 i0 (i - 1)
in
- let fp = compute_fixed_point ctx0 max_num_iter max_num_iter in
+ let fp, fp_svl = compute_fixed_point ctx0 max_num_iter max_num_iter in
(* Make sure we have exactly one loop abstraction per function region (merge
- abstractions accordingly) *)
+ abstractions accordingly).
+
+ Rem.: this shouldn't impact the set of symbolic value ids (because we
+ already merged abstractions "vertically" and are now merging them
+ "horizontally": the symbolic values contained in the abstractions (typically
+ the shared values) will be preserved.
+ *)
let fp =
(* List the loop abstractions in the fixed-point, and set all the abstractions
as endable *)
@@ -2320,7 +2350,7 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
InterpreterBorrows.end_abstraction_no_synth config abs_id ctx
in
(* Explore the context, and check which abstractions are not there anymore *)
- let ids = compute_context_ids ctx in
+ let ids, _ = compute_context_ids ctx in
let ended_ids = V.AbstractionId.Set.diff !fp_aids ids.aids in
add_ended_aids rg_id ended_ids)
ctx.region_groups
@@ -2400,17 +2430,26 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id)
in
let fp = update_kinds_can_end false !fp in
- (* Check that we still have a fixed point - we simply call [compute_fixed_point]
+ (* Sanity check: we still have a fixed point - we simply call [compute_fixed_point]
while allowing exactly one iteration to see if it fails *)
let _ = compute_fixed_point (update_kinds_can_end true fp) 1 1 in
+ (* Sanity check: the set of symbolic value ids is still valid *)
+ let _ =
+ let all_ids, _ = compute_context_ids fp in
+ let fp_sids =
+ V.SymbolicValueId.Set.of_list (List.map (fun sv -> sv.V.sv_id) fp_svl)
+ in
+ assert (V.SymbolicValueId.Set.subset fp_sids all_ids.sids)
+ in
+
(* Return *)
fp
in
let fixed_ids = compute_fixed_ids (Option.get !ctx_upon_entry) fp in
(* Return *)
- (fp, fixed_ids)
+ (fp, fixed_ids, fp_svl)
(** Split an environment between the fixed abstractions, values, etc. and
the new abstractions, values, etc.
@@ -2596,7 +2635,7 @@ let compute_fixed_point_id_correspondance (fixed_ids : ids_sets)
to the same set of source loans and borrows *)
List.iter
(fun abs ->
- let ids = compute_abs_ids abs in
+ let ids, _ = compute_abs_ids abs in
(* Map the *loan* ids (we just match the corresponding *loans* ) *)
let loan_ids =
V.BorrowId.Set.map
@@ -3031,7 +3070,7 @@ let match_ctx_with_target (config : C.config) (loop_id : V.LoopId.id)
| V.Loop (loop_id', rg_id, kind) ->
assert (loop_id' = loop_id);
assert (kind = V.LoopSynthInput);
- let kind = V.Loop (loop_id, rg_id, V.LoopSynthRet) in
+ let kind = V.Loop (loop_id, rg_id, V.LoopCall) in
let abs = { abs with kind } in
super#visit_abs env abs
| _ -> super#visit_abs env abs
@@ -3067,8 +3106,8 @@ let match_ctx_with_target (config : C.config) (loop_id : V.LoopId.id)
(* Continue *)
cc
(cf
- (if is_loop_entry then EndEnterLoop input_values
- else EndContinue input_values))
+ (if is_loop_entry then EndEnterLoop (loop_id, input_values)
+ else EndContinue (loop_id, input_values)))
src_ctx
in
@@ -3130,7 +3169,7 @@ let eval_loop_symbolic (config : C.config) (eval_loop_body : st_cm_fun) :
(* Compute a fresh loop id *)
let loop_id = C.fresh_loop_id () in
(* Compute the fixed point at the loop entrance *)
- let fp_ctx, fixed_ids =
+ let fp_ctx, fixed_ids, input_svalues =
compute_loop_entry_fixed_point config loop_id eval_loop_body ctx
in
@@ -3179,8 +3218,17 @@ let eval_loop_symbolic (config : C.config) (eval_loop_body : st_cm_fun) :
in
let loop_expr = eval_loop_body cf_loop fp_ctx in
+ (* Compute the list of fresh symbolic values *)
+ let fresh_sids =
+ let input_sids =
+ V.SymbolicValueId.Set.of_list
+ (List.map (fun sv -> sv.V.sv_id) input_svalues)
+ in
+ V.SymbolicValueId.Set.diff input_sids fixed_ids.sids
+ in
+
(* Put together *)
- S.synthesize_loop loop_id end_expr loop_expr
+ S.synthesize_loop loop_id input_svalues fresh_sids end_expr loop_expr
(** Evaluate a loop *)
let eval_loop (config : C.config) (eval_loop_body : st_cm_fun) : st_cm_fun =
diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml
index ee8ea212..408184e0 100644
--- a/compiler/InterpreterUtils.ml
+++ b/compiler/InterpreterUtils.ml
@@ -279,6 +279,14 @@ type ids_sets = {
}
[@@deriving show]
+(** See {!compute_typed_value_ids}, {!compute_context_ids}, etc.
+
+ TODO: there misses information.
+ *)
+type ids_to_values = {
+ sids_to_values : V.symbolic_value V.SymbolicValueId.Map.t;
+}
+
let compute_ids () =
let blids = ref V.BorrowId.Set.empty in
let borrow_ids = ref V.BorrowId.Set.empty in
@@ -287,6 +295,7 @@ let compute_ids () =
let dids = ref C.DummyVarId.Set.empty in
let rids = ref T.RegionId.Set.empty in
let sids = ref V.SymbolicValueId.Set.empty in
+ let sids_to_values = ref V.SymbolicValueId.Map.empty in
let get_ids () =
{
@@ -299,9 +308,10 @@ let compute_ids () =
sids = !sids;
}
in
+ let get_ids_to_values () = { sids_to_values = !sids_to_values } in
let obj =
object
- inherit [_] C.iter_eval_ctx
+ inherit [_] C.iter_eval_ctx as super
method! visit_dummy_var_id _ did = dids := C.DummyVarId.Set.add did !dids
method! visit_borrow_id _ id =
@@ -317,37 +327,45 @@ let compute_ids () =
method! visit_region_id _ id = rids := T.RegionId.Set.add id !rids
+ method! visit_symbolic_value env sv =
+ sids := V.SymbolicValueId.Set.add sv.sv_id !sids;
+ sids_to_values := V.SymbolicValueId.Map.add sv.sv_id sv !sids_to_values;
+ super#visit_symbolic_value env sv
+
method! visit_symbolic_value_id _ id =
+ (* TODO: can we get there without going through [visit_symbolic_value] first? *)
sids := V.SymbolicValueId.Set.add id !sids
end
in
- (obj, get_ids)
+ (obj, get_ids, get_ids_to_values)
(** Compute the sets of ids found in a list of typed values. *)
-let compute_typed_values_ids (xl : V.typed_value list) : ids_sets =
- let compute, get_ids = compute_ids () in
+let compute_typed_values_ids (xl : V.typed_value list) :
+ ids_sets * ids_to_values =
+ let compute, get_ids, get_ids_to_values = compute_ids () in
List.iter (compute#visit_typed_value ()) xl;
- get_ids ()
+ (get_ids (), get_ids_to_values ())
(** Compute the sets of ids found in a typed value. *)
-let compute_typed_value_ids (x : V.typed_value) : ids_sets =
+let compute_typed_value_ids (x : V.typed_value) : ids_sets * ids_to_values =
compute_typed_values_ids [ x ]
(** Compute the sets of ids found in a list of abstractions. *)
-let compute_absl_ids (xl : V.abs list) : ids_sets =
- let compute, get_ids = compute_ids () in
+let compute_absl_ids (xl : V.abs list) : ids_sets * ids_to_values =
+ let compute, get_ids, get_ids_to_values = compute_ids () in
List.iter (compute#visit_abs ()) xl;
- get_ids ()
+ (get_ids (), get_ids_to_values ())
(** Compute the sets of ids found in an abstraction. *)
-let compute_abs_ids (x : V.abs) : ids_sets = compute_absl_ids [ x ]
+let compute_abs_ids (x : V.abs) : ids_sets * ids_to_values =
+ compute_absl_ids [ x ]
(** Compute the sets of ids found in a list of contexts. *)
-let compute_contexts_ids (ctxl : C.eval_ctx list) : ids_sets =
- let compute, get_ids = compute_ids () in
+let compute_contexts_ids (ctxl : C.eval_ctx list) : ids_sets * ids_to_values =
+ let compute, get_ids, get_ids_to_values = compute_ids () in
List.iter (compute#visit_eval_ctx ()) ctxl;
- get_ids ()
+ (get_ids (), get_ids_to_values ())
(** Compute the sets of ids found in a context. *)
-let compute_context_ids (ctx : C.eval_ctx) : ids_sets =
+let compute_context_ids (ctx : C.eval_ctx) : ids_sets * ids_to_values =
compute_contexts_ids [ ctx ]
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index e6106eed..10ce876f 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -15,7 +15,7 @@ module FunDeclId = A.FunDeclId
module GlobalDeclId = A.GlobalDeclId
(** We redefine identifiers for loop: in {Values}, the identifiers are global
- (they monotonically increase across functions) while in {!Pure} we want
+ (they monotonically increase across functions) while in {!module:Pure} we want
the indices to start at 0 for every function.
*)
module LoopId = IdGen ()
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 0e245f35..b5c9b686 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -114,7 +114,7 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
We only look for outer monadic let-bindings.
This is used when printing the branches of [if ... then ... else ...].
- Rem.: this function will *fail* if there are {!Loop} nodes (you should call
+ Rem.: this function will *fail* if there are {!constructor:Aeneas.Pure.expression.Loop} nodes (you should call
it on an expression where those nodes have been eliminated).
*)
let rec let_group_requires_parentheses (e : texpression) : bool =
diff --git a/compiler/SCC.ml b/compiler/SCC.ml
index 889a972b..2095c350 100644
--- a/compiler/SCC.ml
+++ b/compiler/SCC.ml
@@ -83,7 +83,7 @@ module Make (Id : OrderedType) = struct
stable manner. For instance, if some Rust functions are mutually recursive,
it is possible that we can extract the forward functions in one group, and
extract the backward functions in one group (after filtering the useless
- calls in {!MicroPasses}), but is is also possible that all the functions
+ calls in {!module:PureMicroPasses}), but is is also possible that all the functions
(forward and backward) are mutually recursive). For this reason, we compute
the dependency graph and the strongly connected components of that graph.
Similar problems when functions contain loops (especially mutually recursive
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 79865e73..4ecb194b 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -60,8 +60,68 @@ type meta =
(** We generated an assignment (destination, assigned value, src) *)
[@@deriving show]
-(** **Rk.:** here, {!expression} is not at all equivalent to the expressions
- used in LLBC: they are a first step towards lambda-calculus expressions.
+type variant_id = T.VariantId.id [@@deriving show]
+type global_decl_id = A.GlobalDeclId.id [@@deriving show]
+type 'a symbolic_value_id_map = 'a V.SymbolicValueId.Map.t [@@deriving show]
+type 'a region_group_id_map = 'a T.RegionGroupId.Map.t [@@deriving show]
+
+(** Ancestor for {!expression} iter visitor *)
+class ['self] iter_expression_base =
+ object (self : 'self)
+ inherit [_] VisitorsRuntime.iter
+ method visit_eval_ctx : 'env -> Contexts.eval_ctx -> unit = fun _ _ -> ()
+ method visit_typed_value : 'env -> V.typed_value -> unit = fun _ _ -> ()
+ method visit_call : 'env -> call -> unit = fun _ _ -> ()
+ method visit_abs : 'env -> V.abs -> unit = fun _ _ -> ()
+ method visit_loop_id : 'env -> V.loop_id -> unit = fun _ _ -> ()
+ method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> ()
+
+ method visit_symbolic_value_id : 'env -> V.symbolic_value_id -> unit =
+ fun _ _ -> ()
+
+ method visit_symbolic_value : 'env -> V.symbolic_value -> unit =
+ fun _ _ -> ()
+
+ method visit_region_group_id : 'env -> T.RegionGroupId.id -> unit =
+ fun _ _ -> ()
+
+ method visit_global_decl_id : 'env -> global_decl_id -> unit = fun _ _ -> ()
+ method visit_mplace : 'env -> mplace -> unit = fun _ _ -> ()
+ method visit_meta : 'env -> meta -> unit = fun _ _ -> ()
+
+ method visit_region_group_id_map
+ : 'a. ('env -> 'a -> unit) -> 'env -> 'a region_group_id_map -> unit =
+ fun f env m ->
+ T.RegionGroupId.Map.iter
+ (fun id x ->
+ self#visit_region_group_id env id;
+ f env x)
+ m
+
+ method visit_symbolic_value_id_map
+ : 'a. ('env -> 'a -> unit) -> 'env -> 'a symbolic_value_id_map -> unit =
+ fun f env m ->
+ V.SymbolicValueId.Map.iter
+ (fun id x ->
+ self#visit_symbolic_value_id env id;
+ f env x)
+ m
+
+ method visit_symbolic_value_id_set : 'env -> V.symbolic_value_id_set -> unit
+ =
+ fun env s ->
+ V.SymbolicValueId.Set.iter (self#visit_symbolic_value_id env) s
+
+ method visit_integer_type : 'env -> T.integer_type -> unit = fun _ _ -> ()
+ method visit_scalar_value : 'env -> V.scalar_value -> unit = fun _ _ -> ()
+
+ method visit_symbolic_expansion : 'env -> V.symbolic_expansion -> unit =
+ fun _ _ -> ()
+ end
+
+(** **Rem.:** here, {!expression} is not at all equivalent to the expressions
+ used in LLBC or in lambda-calculus: they are simply a first step towards
+ lambda-calculus expressions.
*)
type expression =
| Return of Contexts.eval_ctx * V.typed_value option
@@ -83,7 +143,7 @@ type expression =
The context is the evaluation context from after evaluating the asserted
value. It has the same purpose as for the {!Return} case.
*)
- | EvalGlobal of A.GlobalDeclId.id * V.symbolic_value * expression
+ | EvalGlobal of global_decl_id * V.symbolic_value * expression
(** Evaluate a global to a fresh symbolic value *)
| Assertion of Contexts.eval_ctx * V.typed_value * expression
(** An assertion.
@@ -101,9 +161,10 @@ type expression =
to prettify the generated code.
*)
| ForwardEnd of
- V.typed_value V.SymbolicValueId.Map.t option
+ Contexts.eval_ctx
+ * V.typed_value symbolic_value_id_map option
* expression
- * expression T.RegionGroupId.Map.t
+ * expression region_group_id_map
(** We use this delimiter to indicate at which point we switch to the
generation of code specific to the backward function(s). This allows
us in particular to factor the work out: we don't need to replay the
@@ -117,12 +178,28 @@ type expression =
The optional map from symbolic values to input values are input values
for loops: upon entering a loop, in the translation we call the loop
translation function, which takes care of the end of the execution.
+
+ The evaluation context is the context at the moment we introduce the
+ [ForwardEnd], and is used to translate the input values (see the
+ comments for the {!Return} variant).
*)
| Loop of loop (** Loop *)
+ | ReturnWithLoop of V.loop_id * bool
+ (** End the function with a call to a loop function.
+
+ This encompasses the cases when we synthesize a function body
+ and enter a loop for the first time, or when we synthesize a
+ loop body and reach a [Continue].
+
+ The boolean is [is_continue].
+ *)
| Meta of meta * expression (** Meta information *)
and loop = {
- loop_id : V.LoopId.id;
+ loop_id : V.loop_id;
+ input_svalues : V.symbolic_value list; (** The input symbolic values *)
+ fresh_svalues : V.symbolic_value_id_set;
+ (** The symbolic values introduced by the loop fixed-point *)
end_expr : expression;
(** The end of the function (upon the moment it enters the loop) *)
loop_expr : expression; (** The symbolically executed loop body *)
@@ -137,8 +214,7 @@ and expansion =
*Doesn't* include:
- expansion of ADTs with one variant
*)
- | ExpandAdt of
- (T.VariantId.id option * V.symbolic_value list * expression) list
+ | ExpandAdt of (variant_id option * V.symbolic_value list * expression) list
(** ADT expansion *)
| ExpandBool of expression * expression
(** A boolean expansion (i.e, an [if ... then ... else ...]) *)
@@ -146,4 +222,14 @@ and expansion =
T.integer_type * (V.scalar_value * expression) list * expression
(** An integer expansion (i.e, a switch over an integer). The last
expression is for the "otherwise" branch. *)
-[@@deriving show]
+[@@deriving
+ show,
+ visitors
+ {
+ name = "iter_expression";
+ variety = "iter";
+ ancestors = [ "iter_expression_base" ];
+ nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+ concrete = true;
+ polymorphic = false;
+ }]
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index dd662074..006fdda7 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -65,6 +65,46 @@ type call_info = {
*)
}
+(** Contains information about a loop we entered.
+
+ Note that a path in a translated function body can have at most one call to
+ a loop, because the loop function takes care of the end of the execution
+ (and always happen at the end of the function). To be more precise, if we
+ translate a function body which contains a loop, one of the leaves will be a
+ call to the loop translation. The same happens for loop bodies.
+
+ For instance, if in Rust we have:
+ {[
+ fn get(...) {
+ let x = f(...);
+
+ loop {
+ ...
+ }
+ }
+ ]}
+
+ Then in the translation we have:
+ {[
+ let get_fwd ... =
+ let x = f_fwd ... in
+ (* We end the function by calling the loop translation *)
+ get_fwd_loop ...
+ ]}
+
+ The various input and output fields are for this unique loop call, if
+ there is one.
+ *)
+type loop_info = {
+ loop_id : LoopId.id;
+ input_svl : V.symbolic_value list;
+ type_args : ty list;
+ forward_inputs : texpression list option;
+ (** The forward inputs are initialized at [None] *)
+ forward_output_no_state : var option;
+ (** The forward outputs are initialized at [None] *)
+}
+
(** Body synthesis context *)
type bs_ctx = {
type_context : type_context;
@@ -119,7 +159,14 @@ type bs_ctx = {
(** The function calls we encountered so far *)
abstractions : (V.abs * texpression list) V.AbstractionId.Map.t;
(** The ended abstractions we encountered so far, with their additional input arguments *)
- loop_id : V.LoopId.id option;
+ loop_ids_map : LoopId.id V.LoopId.Map.t; (** Ids to use for the loops *)
+ loops : loop_info LoopId.Map.t;
+ (** The loops we encountered so far.
+
+ We are using a map to be general - in practice we will fail if we encounter
+ more than one loop on a single path.
+ *)
+ loop_id : LoopId.id option;
(** [Some] if we reached a loop (we are synthesizing a function, and reached a loop, or are
synthesizing the loop body itself)
*)
@@ -535,7 +582,8 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
(** Small utility. *)
let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info =
+ (fun_id : A.fun_id) (lid : V.LoopId.id option)
+ (gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
| A.Regular fid ->
let info = A.FunDeclId.Map.find fid fun_infos in
@@ -548,9 +596,10 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
stateful_group;
stateful;
can_diverge = info.can_diverge;
- is_rec = info.is_rec;
+ is_rec = info.is_rec || Option.is_some lid;
}
| A.Assumed aid ->
+ assert (lid = None);
{
can_fail = Assumed.assumed_can_fail aid;
stateful_group = false;
@@ -579,7 +628,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(Some bid, parents)
in
(* Is the function stateful, and can it fail? *)
- let effect_info = get_fun_effect_info fun_infos fun_id bid in
+ let lid = None in
+ let effect_info = get_fun_effect_info fun_infos fun_id lid bid in
(* List the inputs for:
* - the fuel
* - the forward function
@@ -728,28 +778,37 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
let sg = { type_params; inputs; output; doutputs; info } in
{ sg; output_names }
-let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
+let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern =
(* Generate the fresh variable *)
let id, var_counter = VarId.fresh ctx.var_counter in
- let var =
+ let state_var =
{ id; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
in
- let state_var = mk_typed_pattern_from_var var None in
+ let state_pat = mk_typed_pattern_from_var state_var None in
(* Update the context *)
let ctx = { ctx with var_counter; state_var = id } in
(* Return *)
- (ctx, state_var)
+ (ctx, state_var, state_pat)
-let fresh_named_var_for_symbolic_value (basename : string option)
- (sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var =
+let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) :
+ bs_ctx * var =
(* Generate the fresh variable *)
let id, var_counter = VarId.fresh ctx.var_counter in
- let ty = ctx_translate_fwd_ty ctx sv.sv_ty in
+ let ty = ctx_translate_fwd_ty ctx ty in
let var = { id; basename; ty } in
+ (* Update the context *)
+ let ctx = { ctx with var_counter } in
+ (* Return *)
+ (ctx, var)
+
+let fresh_named_var_for_symbolic_value (basename : string option)
+ (sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var =
+ (* Generate the fresh variable *)
+ let ctx, var = fresh_var basename sv.sv_ty ctx in
(* Insert in the map *)
let sv_to_var = V.SymbolicValueId.Map.add sv.sv_id var ctx.sv_to_var in
(* Update the context *)
- let ctx = { ctx with var_counter; sv_to_var } in
+ let ctx = { ctx with sv_to_var } in
(* Return *)
(ctx, var)
@@ -1136,9 +1195,13 @@ and aproj_to_given_back (mp : mplace option) (aproj : V.aproj) (ctx : bs_ctx) :
See [typed_avalue_to_given_back].
*)
-let abs_to_given_back (mpl : mplace option list) (abs : V.abs) (ctx : bs_ctx) :
- bs_ctx * typed_pattern list =
- let avalues = List.combine mpl abs.avalues in
+let abs_to_given_back (mpl : mplace option list option) (abs : V.abs)
+ (ctx : bs_ctx) : bs_ctx * typed_pattern list =
+ let avalues =
+ match mpl with
+ | None -> List.map (fun av -> (None, av)) abs.avalues
+ | Some mpl -> List.combine mpl abs.avalues
+ in
let ctx, values =
List.fold_left_map
(fun ctx (mp, av) -> typed_avalue_to_given_back mp av ctx)
@@ -1151,7 +1214,7 @@ let abs_to_given_back (mpl : mplace option list) (abs : V.abs) (ctx : bs_ctx) :
let abs_to_given_back_no_mp (abs : V.abs) (ctx : bs_ctx) :
bs_ctx * typed_pattern list =
let mpl = List.map (fun _ -> None) abs.avalues in
- abs_to_given_back mpl abs ctx
+ abs_to_given_back (Some mpl) abs ctx
(** Return the ordered list of the (transitive) parents of a given abstraction.
@@ -1167,6 +1230,8 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) (call_id : V.FunCallId.id) :
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
| S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx
+ | ReturnWithLoop (loop_id, is_continue) ->
+ translate_return_with_loop loop_id is_continue ctx
| Panic -> translate_panic ctx
| FunCall (call, e) -> translate_function_call call e ctx
| EndAbstraction (ectx, abs, e) -> translate_end_abstraction ectx abs e ctx
@@ -1174,19 +1239,8 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
| Assertion (ectx, v, e) -> translate_assertion ectx v e ctx
| Expansion (p, sv, exp) -> translate_expansion p sv exp ctx
| Meta (meta, e) -> translate_meta meta e ctx
- | ForwardEnd (loop_input_values, e, back_e) ->
- assert (loop_input_values = None);
- (* Update the current state with the additional state received by the backward
- function, if needs be, and lookup the proper expression *)
- let ctx, e =
- match ctx.bid with
- | None -> (ctx, e)
- | Some bid ->
- let ctx = { ctx with state_var = ctx.back_state_var } in
- let e = T.RegionGroupId.Map.find bid back_e in
- (ctx, e)
- in
- translate_expression e ctx
+ | ForwardEnd (ectx, loop_input_values, e, back_e) ->
+ translate_forward_end ectx loop_input_values e back_e ctx
| Loop loop -> translate_loop loop ctx
and translate_panic (ctx : bs_ctx) : texpression =
@@ -1206,12 +1260,12 @@ and translate_panic (ctx : bs_ctx) : texpression =
ret_v
else mk_result_fail_texpression_with_error_id error_failure_id output_ty
-(** [opt_v]: the value to return, in case we translate a forward function *)
+(** [opt_v]: the value to return, in case we translate a forward body *)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
- - either we are translating a forward function, in which case the optional
- value should be [Some] (it is the returned value)
+ - either we reach the return of a forward function or a forward loop body,
+ in which case the optional value should be [Some] (it is the returned value)
- or we are translating a backward function, in which case it should be [None]
*)
(* Compute the values that we should return *without the state and the result
@@ -1246,7 +1300,52 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
else output
in
(* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
- (* TODO: we should use a [Return] function *)
+ mk_result_return_texpression output
+
+and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
+ (ctx : bs_ctx) : texpression =
+ assert (is_continue = ctx.inside_loop);
+ let loop_id = V.LoopId.Map.find loop_id ctx.loop_ids_map in
+ assert (loop_id = Option.get ctx.loop_id);
+
+ (* Lookup the loop information *)
+ let loop_id = Option.get ctx.loop_id in
+ let loop_info = LoopId.Map.find loop_id ctx.loops in
+
+ (* There are two cases depending on whether we translate a backward function
+ or not.
+ *)
+ let output =
+ match ctx.bid with
+ | None ->
+ (* Forward *)
+ mk_texpression_from_var (Option.get loop_info.forward_output_no_state)
+ | Some bid ->
+ (* Backward *)
+ (* Group the variables in which we stored the values we need to give back.
+ * See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
+ let backward_outputs =
+ T.RegionGroupId.Map.find bid ctx.backward_outputs
+ in
+ let field_values = List.map mk_texpression_from_var backward_outputs in
+ mk_simpl_tuple_texpression field_values
+ in
+
+ (* We may need to return a state
+ * - error-monad: Return x
+ * - state-error: Return (state, x)
+ * Note that the loop function and the parent function live in the same
+ * effect - in particular, one manipulates a state iff the other does
+ * the same.
+ * *)
+ let effect_info = ctx.sg.info.effect_info in
+ let output =
+ if effect_info.stateful then
+ let state_rvalue = mk_state_texpression ctx.state_var in
+ mk_simpl_tuple_texpression [ state_rvalue; output ]
+ else output
+ in
+ (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
mk_result_return_texpression output
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
@@ -1272,18 +1371,18 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None
+ get_fun_effect_info ctx.fun_context.fun_infos fid None None
in
- (* If the function is stateful:
- * - add the fuel
- * - add the state input argument
- * - generate a fresh state variable for the returned state
- *)
+ (* Depending on the function effects:
+ * - add the fuel
+ * - add the state input argument
+ * - generate a fresh state variable for the returned state
+ *)
let args, ctx, out_state =
let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
- let ctx, nstate_var = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
@@ -1375,80 +1474,281 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
^ V.show_abs_kind abs.kind));
match abs.kind with
| V.SynthInput rg_id ->
- (* When we end an input abstraction, this input abstraction gets back
- * the borrows which it introduced in the context through the input
- * values: by listing those values, we get the values which are given
- * back by one of the backward functions we are synthesizing. *)
- (* Note that we don't support nested borrows for now: if we find
- * an ended synthesized input abstraction, it must be the one corresponding
- * to the backward function wer are synthesizing, it can't be the one
- * for a parent backward function.
- *)
- let bid = Option.get ctx.bid in
- assert (rg_id = bid);
-
- (* The translation is done as follows:
- * - for a given backward function, we choose a set of variables [v_i]
- * - when we detect the ended input abstraction which corresponds
- * to the backward function, and which consumed the values [consumed_i],
- * we introduce:
- * {[
+ translate_end_abstraction_synth_input ectx abs e ctx rg_id
+ | V.FunCall (call_id, rg_id) ->
+ translate_end_abstraction_fun_call ectx abs e ctx call_id rg_id
+ | V.SynthRet rg_id -> translate_end_abstraction_synth_ret ectx abs e ctx rg_id
+ | Loop (loop_id, rg_id, abs_kind) ->
+ translate_end_abstraction_loop ectx abs e ctx loop_id rg_id abs_kind
+
+and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
+ (e : S.expression) (ctx : bs_ctx) (rg_id : T.RegionGroupId.id) : texpression
+ =
+ (* When we end an input abstraction, this input abstraction gets back
+ * the borrows which it introduced in the context through the input
+ * values: by listing those values, we get the values which are given
+ * back by one of the backward functions we are synthesizing. *)
+ (* Note that we don't support nested borrows for now: if we find
+ * an ended synthesized input abstraction, it must be the one corresponding
+ * to the backward function wer are synthesizing, it can't be the one
+ * for a parent backward function.
+ *)
+ let bid = Option.get ctx.bid in
+ assert (rg_id = bid);
+
+ (* The translation is done as follows:
+ * - for a given backward function, we choose a set of variables [v_i]
+ * - when we detect the ended input abstraction which corresponds
+ * to the backward function, and which consumed the values [consumed_i],
+ * we introduce:
+ * {[
* let v_i = consumed_i in
* ...
- * ]}
- * Then, when we reach the [Return] node, we introduce:
- * {[
+ * ]}
+ * Then, when we reach the [Return] node, we introduce:
+ * {[
* (v_i)
- * ]}
- * *)
- (* First, get the given back variables *)
- let given_back_variables =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
- in
- (* Get the list of values consumed by the abstraction upon ending *)
- let consumed_values = abs_to_consumed ctx ectx abs in
- (* Group the two lists *)
- let variables_values =
- List.combine given_back_variables consumed_values
- in
- (* Sanity check: the two lists match (same types) *)
- List.iter
- (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
- variables_values;
- (* Translate the next expression *)
- let next_e = translate_expression e ctx in
- (* Generate the assignemnts *)
- let monadic = false in
- List.fold_right
- (fun (var, value) (e : texpression) ->
- mk_let monadic (mk_typed_pattern_from_var var None) value e)
- variables_values next_e
- | V.FunCall (call_id, rg_id) ->
- let call_info = V.FunCallId.Map.find call_id ctx.calls in
- let call = call_info.forward in
- let fun_id =
- match call.call_id with
- | S.Fun (fun_id, _) -> fun_id
- | Unop _ | Binop _ ->
- (* Those don't have backward functions *)
- raise (Failure "Unreachable")
- in
+ * ]}
+ * *)
+ (* First, get the given back variables *)
+ let given_back_variables =
+ T.RegionGroupId.Map.find bid ctx.backward_outputs
+ in
+ (* Get the list of values consumed by the abstraction upon ending *)
+ let consumed_values = abs_to_consumed ctx ectx abs in
+ (* Group the two lists *)
+ let variables_values = List.combine given_back_variables consumed_values in
+ (* Sanity check: the two lists match (same types) *)
+ List.iter
+ (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
+ variables_values;
+ (* Translate the next expression *)
+ let next_e = translate_expression e ctx in
+ (* Generate the assignemnts *)
+ let monadic = false in
+ List.fold_right
+ (fun (var, value) (e : texpression) ->
+ mk_let monadic (mk_typed_pattern_from_var var None) value e)
+ variables_values next_e
+
+and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
+ (e : S.expression) (ctx : bs_ctx) (call_id : V.FunCallId.id)
+ (rg_id : T.RegionGroupId.id) : texpression =
+ let call_info = V.FunCallId.Map.find call_id ctx.calls in
+ let call = call_info.forward in
+ let fun_id =
+ match call.call_id with
+ | S.Fun (fun_id, _) -> fun_id
+ | Unop _ | Binop _ ->
+ (* Those don't have backward functions *)
+ raise (Failure "Unreachable")
+ in
+ let effect_info =
+ get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id)
+ in
+ let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
+ (* Retrieve the original call and the parent abstractions *)
+ let _forward, backwards = get_abs_ancestors ctx abs call_id in
+ (* Retrieve the values consumed when we called the forward function and
+ * ended the parent backward functions: those give us part of the input
+ * values (rem: for now, as we disallow nested lifetimes, there can't be
+ * parent backward functions).
+ * Note that the forward inputs **include the fuel and the input state**
+ * (if we use those). *)
+ let fwd_inputs = call_info.forward_inputs in
+ let back_ancestors_inputs =
+ List.concat (List.map (fun (_abs, args) -> args) backwards)
+ in
+ (* Retrieve the values consumed upon ending the loans inside this
+ * abstraction: those give us the remaining input values *)
+ let back_inputs = abs_to_consumed ctx ectx abs in
+ (* If the function is stateful:
+ * - add the state input argument
+ * - generate a fresh state variable for the returned state
+ *)
+ let back_state, ctx, nstate =
+ if effect_info.stateful then
+ let back_state = mk_state_texpression ctx.state_var in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
+ ([ back_state ], ctx, Some nstate)
+ else ([], ctx, None)
+ in
+ (* Concatenate all the inpus *)
+ let inputs =
+ List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ]
+ in
+ (* Retrieve the values given back by this function: those are the output
+ * values. We rely on the fact that there are no nested borrows to use the
+ * meta-place information from the input values given to the forward function
+ * (we need to add [None] for the return avalue) *)
+ let output_mpl =
+ List.append (List.map translate_opt_mplace call.args_places) [ None ]
+ in
+ let ctx, outputs = abs_to_given_back (Some output_mpl) abs ctx in
+ (* Group the output values together: first the updated inputs *)
+ let output = mk_simpl_tuple_pattern outputs in
+ (* Add the returned state if the function is stateful *)
+ let output =
+ match nstate with
+ | None -> output
+ | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
+ in
+ (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *)
+ let _ =
+ let inst_sg = get_instantiated_fun_sig fun_id (Some rg_id) type_args ctx in
+ log#ldebug
+ (lazy
+ ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
+ ^ string_of_int (List.length inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map show_texpression inputs)
+ ^ "\n- inst_sg.inputs ("
+ ^ string_of_int (List.length inst_sg.inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map show_ty inst_sg.inputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : texpression).ty = ty))
+ (List.combine inputs inst_sg.inputs);
+ log#ldebug
+ (lazy
+ ("\n- outputs: "
+ ^ string_of_int (List.length outputs)
+ ^ "\n- expected outputs: "
+ ^ string_of_int (List.length inst_sg.doutputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
+ (List.combine outputs inst_sg.doutputs)
+ in
+ (* Retrieve the function id, and register the function call in the context
+ * if necessary *)
+ let ctx, func =
+ bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
+ in
+ (* Translate the next expression *)
+ let next_e = translate_expression e ctx in
+ (* Put everything together *)
+ let args_mplaces = List.map (fun _ -> None) inputs in
+ let args =
+ List.map
+ (fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
+ (List.combine inputs args_mplaces)
+ in
+ let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty output.ty else output.ty
+ in
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = { id = FunOrOp func; type_args } in
+ let func = { e = Qualif func; ty = func_ty } in
+ let call = mk_apps func args in
+ (* **Optimization**:
+ * =================
+ * We do a small optimization here: if the backward function doesn't
+ * have any output, we don't introduce any function call.
+ * See the comment in {!Config.filter_useless_monadic_calls}.
+ *
+ * TODO: use an option to disallow backward functions from updating the state.
+ * TODO: a backward function which only gives back shared borrows shouldn't
+ * update the state (state updates should only be used for mutable borrows,
+ * with objects like Rc for instance).
+ *)
+ if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then (
+ (* No outputs - we do a small sanity check: the backward function
+ * should have exactly the same number of inputs as the forward:
+ * this number can be different only if the forward function returned
+ * a value containing mutable borrows, which can't be the case... *)
+ assert (List.length inputs = List.length fwd_inputs);
+ next_e)
+ else mk_let effect_info.can_fail output call next_e
+
+and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs)
+ (e : S.expression) (ctx : bs_ctx) (rg_id : T.RegionGroupId.id) : texpression
+ =
+ (* If we end the abstraction which consumed the return value of the function
+ we are synthesizing, we get back the borrows which were inside. Those borrows
+ are actually input arguments of the backward function we are synthesizing.
+ So we simply need to introduce proper let bindings.
+
+ For instance:
+ {[
+ fn id<'a>(x : &'a mut u32) -> &'a mut u32 {
+ x
+ }
+ ]}
+
+ Upon ending the return abstraction for 'a, we get back the borrow for [x].
+ This new value is the second argument of the backward function:
+ {[
+ let id_back x nx = nx
+ ]}
+
+ In practice, upon ending this abstraction we introduce a useless
+ let-binding:
+ {[
+ let id_back x nx =
+ let s = nx in // the name [s] is not important (only collision matters)
+ ...
+ ]}
+
+ This let-binding later gets inlined, during a micro-pass.
+ *)
+ (* First, retrieve the list of variables used for the inputs for the
+ * backward function *)
+ let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in
+ (* Retrieve the values consumed upon ending the loans inside this
+ * abstraction: as there are no nested borrows, there should be none. *)
+ let consumed = abs_to_consumed ctx ectx abs in
+ assert (consumed = []);
+ (* Retrieve the values given back upon ending this abstraction - note that
+ * we don't provide meta-place information, because those assignments will
+ * be inlined anyway... *)
+ log#ldebug (lazy ("abs: " ^ abs_to_string ctx abs));
+ let ctx, given_back = abs_to_given_back_no_mp abs ctx in
+ (* Link the inputs to those given back values - note that this also
+ * checks we have the same number of values, of course *)
+ let given_back_inputs = List.combine given_back inputs in
+ (* Sanity check *)
+ List.iter
+ (fun ((given_back, input) : typed_pattern * var) ->
+ log#ldebug
+ (lazy
+ ("\n- given_back ty: "
+ ^ ty_to_string ctx given_back.ty
+ ^ "\n- sig input ty: " ^ ty_to_string ctx input.ty));
+ assert (given_back.ty = input.ty))
+ given_back_inputs;
+ (* Translate the next expression *)
+ let next_e = translate_expression e ctx in
+ (* Generate the assignments *)
+ let monadic = false in
+ List.fold_right
+ (fun (given_back, input_var) e ->
+ mk_let monadic given_back (mk_texpression_from_var input_var) e)
+ given_back_inputs next_e
+
+and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
+ (e : S.expression) (ctx : bs_ctx) (loop_id : V.LoopId.id)
+ (rg_id : T.RegionGroupId.id option) (abs_kind : V.loop_abs_kind) :
+ texpression =
+ let vloop_id = loop_id in
+ let loop_id = V.LoopId.Map.find loop_id ctx.loop_ids_map in
+ assert (loop_id = Option.get ctx.loop_id);
+ let rg_id = Option.get rg_id in
+ (* There are two cases depending on the [abs_kind] (whether this is a
+ synth input or a regular loop call) *)
+ match abs_kind with
+ | V.LoopSynthInput ->
+ (* Actually the same case as [SynthInput] *)
+ translate_end_abstraction_synth_input ectx abs e ctx rg_id
+ | V.LoopCall ->
+ let fun_id = A.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some rg_id)
- in
- let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- (* Retrieve the original call and the parent abstractions *)
- let _forward, backwards = get_abs_ancestors ctx abs call_id in
- (* Retrieve the values consumed when we called the forward function and
- * ended the parent backward functions: those give us part of the input
- * values (rem: for now, as we disallow nested lifetimes, there can't be
- * parent backward functions).
- * Note that the forward inputs **include the fuel and the input state**
- * (if we use those). *)
- let fwd_inputs = call_info.forward_inputs in
- let back_ancestors_inputs =
- List.concat (List.map (fun (_abs, args) -> args) backwards)
+ get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id)
+ (Some rg_id)
in
+ let loop_info = LoopId.Map.find loop_id ctx.loops in
+ let type_args = loop_info.type_args in
+ let fwd_inputs = Option.get loop_info.forward_inputs in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: those give us the remaining input values *)
let back_inputs = abs_to_consumed ctx ectx abs in
@@ -1459,23 +1759,14 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
let back_state, ctx, nstate =
if effect_info.stateful then
let back_state = mk_state_texpression ctx.state_var in
- let ctx, nstate = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
(* Concatenate all the inpus *)
- let inputs =
- List.concat
- [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ]
- in
- (* Retrieve the values given back by this function: those are the output
- * values. We rely on the fact that there are no nested borrows to use the
- * meta-place information from the input values given to the forward function
- * (we need to add [None] for the return avalue) *)
- let output_mpl =
- List.append (List.map translate_opt_mplace call.args_places) [ None ]
- in
- let ctx, outputs = abs_to_given_back output_mpl abs ctx in
+ let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in
+ (* Retrieve the values given back by this function *)
+ let ctx, outputs = abs_to_given_back None abs ctx in
(* Group the output values together: first the updated inputs *)
let output = mk_simpl_tuple_pattern outputs in
(* Add the returned state if the function is stateful *)
@@ -1484,39 +1775,6 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
| None -> output
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
- (* Sanity check: the inputs and outputs have the proper number and the proper type *)
- let _ =
- let inst_sg =
- get_instantiated_fun_sig fun_id (Some rg_id) type_args ctx
- in
- log#ldebug
- (lazy
- ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
- ^ string_of_int (List.length inputs)
- ^ "): "
- ^ String.concat ", " (List.map show_texpression inputs)
- ^ "\n- inst_sg.inputs ("
- ^ string_of_int (List.length inst_sg.inputs)
- ^ "): "
- ^ String.concat ", " (List.map show_ty inst_sg.inputs)));
- List.iter
- (fun (x, ty) -> assert ((x : texpression).ty = ty))
- (List.combine inputs inst_sg.inputs);
- log#ldebug
- (lazy
- ("\n- outputs: "
- ^ string_of_int (List.length outputs)
- ^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.doutputs)));
- List.iter
- (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.doutputs)
- in
- (* Retrieve the function id, and register the function call in the context
- * if necessary *)
- let ctx, func =
- bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
- in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
@@ -1531,6 +1789,7 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
+ let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in
let func = { id = FunOrOp func; type_args } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
@@ -1543,7 +1802,7 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
* TODO: use an option to disallow backward functions from updating the state.
* TODO: a backward function which only gives back shared borrows shouldn't
* update the state (state updates should only be used for mutable borrows,
- * with objects like Rc for instance.
+ * with objects like Rc for instance).
*)
if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None
then (
@@ -1554,69 +1813,6 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else mk_let effect_info.can_fail output call next_e
- | V.SynthRet rg_id ->
- (* If we end the abstraction which consumed the return value of the function
- we are synthesizing, we get back the borrows which were inside. Those borrows
- are actually input arguments of the backward function we are synthesizing.
- So we simply need to introduce proper let bindings.
-
- For instance:
- {[
- fn id<'a>(x : &'a mut u32) -> &'a mut u32 {
- x
- }
- ]}
-
- Upon ending the return abstraction for 'a, we get back the borrow for [x].
- This new value is the second argument of the backward function:
- {[
- let id_back x nx = nx
- ]}
-
- In practice, upon ending this abstraction we introduce a useless
- let-binding:
- {[
- let id_back x nx =
- let s = nx in // the name [s] is not important (only collision matters)
- ...
- ]}
-
- This let-binding later gets inlined, during a micro-pass.
- *)
- (* First, retrieve the list of variables used for the inputs for the
- * backward function *)
- let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in
- (* Retrieve the values consumed upon ending the loans inside this
- * abstraction: as there are no nested borrows, there should be none. *)
- let consumed = abs_to_consumed ctx ectx abs in
- assert (consumed = []);
- (* Retrieve the values given back upon ending this abstraction - note that
- * we don't provide meta-place information, because those assignments will
- * be inlined anyway... *)
- log#ldebug (lazy ("abs: " ^ abs_to_string ctx abs));
- let ctx, given_back = abs_to_given_back_no_mp abs ctx in
- (* Link the inputs to those given back values - note that this also
- * checks we have the same number of values, of course *)
- let given_back_inputs = List.combine given_back inputs in
- (* Sanity check *)
- List.iter
- (fun ((given_back, input) : typed_pattern * var) ->
- log#ldebug
- (lazy
- ("\n- given_back ty: "
- ^ ty_to_string ctx given_back.ty
- ^ "\n- sig input ty: " ^ ty_to_string ctx input.ty));
- assert (given_back.ty = input.ty))
- given_back_inputs;
- (* Translate the next expression *)
- let next_e = translate_expression e ctx in
- (* Generate the assignments *)
- let monadic = false in
- List.fold_right
- (fun (given_back, input_var) e ->
- mk_let monadic given_back (mk_texpression_from_var input_var) e)
- given_back_inputs next_e
- | Loop _ -> raise (Failure "Unimplemented")
and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -1841,8 +2037,177 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches);
{ e; ty }
+and translate_forward_end (ectx : C.eval_ctx)
+ (loop_input_values : V.typed_value S.symbolic_value_id_map option)
+ (e : S.expression) (back_e : S.expression S.region_group_id_map)
+ (ctx : bs_ctx) : texpression =
+ (* Update the current state with the additional state received by the backward
+ function, if needs be, and lookup the proper expression *)
+ let translate_end ctx =
+ (* Update the current state with the additional state received by the backward
+ function, if needs be, and lookup the proper expression *)
+ let ctx, e =
+ match ctx.bid with
+ | None -> (ctx, e)
+ | Some bid ->
+ let ctx = { ctx with state_var = ctx.back_state_var } in
+ let e = T.RegionGroupId.Map.find bid back_e in
+ (ctx, e)
+ in
+ translate_expression e ctx
+ in
+
+ (* If we entered/are entering a loop, we need to introduce a call to the
+ forward translation of the loop. *)
+ match loop_input_values with
+ | None ->
+ (* "Regular" case: not a loop *)
+ assert (ctx.loop_id = None);
+ translate_end ctx
+ | Some loop_input_values ->
+ (* Loop *)
+ let loop_id = Option.get ctx.loop_id in
+
+ (* Lookup the loop information *)
+ let loop_info = LoopId.Map.find loop_id ctx.loops in
+
+ (* Translate the input values *)
+ let loop_input_values =
+ List.map
+ (fun sv -> V.SymbolicValueId.Map.find sv.V.sv_id loop_input_values)
+ loop_info.input_svl
+ in
+ let args =
+ List.map (typed_value_to_texpression ctx ectx) loop_input_values
+ in
+
+ (* Lookup the effect info for the loop function *)
+ let fid = A.Regular ctx.fun_decl.A.def_id in
+ let effect_info =
+ get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid
+ in
+
+ (* Introduce a fresh output value for the forward function *)
+ let ctx, output_var =
+ let output_ty = ctx.sg.output in
+ fresh_var None output_ty ctx
+ in
+ let args, ctx, out_pats =
+ let output_pat = mk_typed_pattern_from_var output_var None in
+
+ (* Depending on the function effects:
+ * - add the fuel
+ * - add the state input argument
+ * - generate a fresh state variable for the returned state
+ * TODO: we do exactly the same thing in {!translate_function_call}
+ *)
+ let fuel = mk_fuel_input_as_list ctx effect_info in
+ if effect_info.stateful then
+ let state_var = mk_state_texpression ctx.state_var in
+ let ctx, _nstate_var, nstate_pat = bs_ctx_fresh_state_var ctx in
+ ( List.concat [ fuel; args; [ state_var ] ],
+ ctx,
+ [ nstate_pat; output_pat ] )
+ else (List.concat [ fuel; args ], ctx, [ output_pat ])
+ in
+
+ (* Update the loop information in the context *)
+ let loop_info =
+ {
+ loop_info with
+ forward_inputs = Some args;
+ forward_output_no_state = Some output_var;
+ }
+ in
+ let ctx =
+ { ctx with loops = LoopId.Map.add loop_id loop_info ctx.loops }
+ in
+
+ (* Translate the end of the function *)
+ let next_e = translate_end ctx in
+
+ (* Introduce the call to the loop in the generated AST *)
+ let out_pat = mk_simpl_tuple_pattern out_pats in
+ let loop_call =
+ let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in
+ let func = { id = FunOrOp fun_id; type_args = loop_info.type_args } in
+ let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty
+ in
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = { e = Qualif func; ty = func_ty } in
+ let call = mk_apps func args in
+ call
+ in
+ mk_let effect_info.can_fail out_pat loop_call next_e
+
and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
- raise (Failure "Unreachable")
+ let loop_id = V.LoopId.Map.find loop.loop_id ctx.loop_ids_map in
+
+ (* Translate the loop inputs *)
+ let inputs =
+ List.map
+ (fun sv -> V.SymbolicValueId.Map.find sv.V.sv_id ctx.sv_to_var)
+ loop.input_svalues
+ in
+ let inputs_lvs =
+ List.map (fun var -> mk_typed_pattern_from_var var None) inputs
+ in
+
+ (* Add the loop information in the context *)
+ let ctx =
+ assert (not (LoopId.Map.mem loop_id ctx.loops));
+
+ (* Note that we will retrieve the input values later in the [ForwardEnd]
+ (and will introduce the outputs at that moment, together with the actual
+ call to the loop forward function *)
+ let type_args =
+ List.map (fun ty -> TypeVar ty.T.index) ctx.sg.type_params
+ in
+
+ let loop_info =
+ {
+ loop_id;
+ input_svl = loop.input_svalues;
+ type_args;
+ forward_inputs = None;
+ forward_output_no_state = None;
+ }
+ in
+ let loops = LoopId.Map.add loop_id loop_info ctx.loops in
+ { ctx with loops }
+ in
+
+ (* Update the context to translate the function end *)
+ let ctx_end = { ctx with loop_id = Some loop_id } in
+ let fun_end = translate_expression loop.end_expr ctx_end in
+
+ (* Update the context for the loop body *)
+ let ctx_loop = { ctx_end with inside_loop = true } in
+ (* We also need to introduce variables for the symbolic values which are
+ introduced in the fixed point (we have to filter the list of symbolic
+ values, to remove the not fresh ones - the fixed point introduces some
+ symbolic values and keeps some others)... *)
+ let ctx_loop =
+ let svl =
+ List.filter
+ (fun (sv : V.symbolic_value) ->
+ V.SymbolicValueId.Set.mem sv.sv_id loop.fresh_svalues)
+ loop.input_svalues
+ in
+ let ctx_loop, _ = fresh_vars_for_symbolic_values svl ctx_loop in
+ ctx_loop
+ in
+
+ (* Translate the loop body *)
+ let loop_body = translate_expression loop.loop_expr ctx_loop in
+
+ (* Create the loop node and return *)
+ let loop = Loop { fun_end; loop_id; inputs; inputs_lvs; loop_body } in
+ assert (fun_end.ty = loop_body.ty);
+ let ty = fun_end.ty in
+ { e = loop; ty }
and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
texpression =
@@ -1947,7 +2312,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid
+ get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None
+ bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index 4bb6529b..8c06717a 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -156,15 +156,17 @@ let synthesize_assertion (ctx : Contexts.eval_ctx) (v : V.typed_value)
(e : expression option) =
Option.map (fun e -> Assertion (ctx, v, e)) e
-let synthesize_forward_end
+let synthesize_forward_end (ctx : Contexts.eval_ctx)
(loop_input_values : V.typed_value V.SymbolicValueId.Map.t option)
(e : expression) (el : expression T.RegionGroupId.Map.t) =
- Some (ForwardEnd (loop_input_values, e, el))
+ Some (ForwardEnd (ctx, loop_input_values, e, el))
-let synthesize_loop (loop_id : V.LoopId.id) (end_expr : expression option)
+let synthesize_loop (loop_id : V.LoopId.id)
+ (input_svalues : V.symbolic_value list)
+ (fresh_svalues : V.SymbolicValueId.Set.t) (end_expr : expression option)
(loop_expr : expression option) : expression option =
match (end_expr, loop_expr) with
| None, None -> None
| Some end_expr, Some loop_expr ->
- Some (Loop { loop_id; end_expr; loop_expr })
+ Some (Loop { loop_id; input_svalues; fresh_svalues; end_expr; loop_expr })
| _ -> raise (Failure "Unreachable")
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 75aeb37c..32c32ac4 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -92,6 +92,36 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
let global_context =
{ SymbolicToPure.llbc_global_decls = global_context.global_decls }
in
+
+ (* Compute the set of loops, and find better ids for them (starting at 0).
+ Note that we only need to explore the forward function: the backward
+ functions should contain the same set of loops.
+ *)
+ let loop_ids_map =
+ match symbolic_trans with
+ | None -> V.LoopId.Map.empty
+ | Some (_, ast) ->
+ let m = ref V.LoopId.Map.empty in
+ let _, fresh_loop_id = Pure.LoopId.fresh_stateful_generator () in
+
+ let visitor =
+ object
+ inherit [_] SA.iter_expression as super
+
+ method! visit_loop env loop =
+ let _ =
+ match V.LoopId.Map.find_opt loop.loop_id !m with
+ | Some _ -> ()
+ | None ->
+ m := V.LoopId.Map.add loop.loop_id (fresh_loop_id ()) !m
+ in
+ super#visit_loop env loop
+ end
+ in
+ visitor#visit_expression () ast;
+ !m
+ in
+
let ctx =
{
SymbolicToPure.bid = None;
@@ -118,6 +148,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
abstractions;
loop_id = None;
inside_loop = false;
+ loop_ids_map;
+ loops = Pure.LoopId.Map.empty;
}
in
diff --git a/compiler/Values.ml b/compiler/Values.ml
index f9b4e423..7d5ecc01 100644
--- a/compiler/Values.ml
+++ b/compiler/Values.ml
@@ -16,6 +16,8 @@ type big_int = PrimitiveValues.big_int [@@deriving show, ord]
type scalar_value = PrimitiveValues.scalar_value [@@deriving show, ord]
type primitive_value = PrimitiveValues.primitive_value [@@deriving show, ord]
type symbolic_value_id = SymbolicValueId.id [@@deriving show, ord]
+type symbolic_value_id_set = SymbolicValueId.Set.t [@@deriving show, ord]
+type loop_id = LoopId.id [@@deriving show, ord]
(** The kind of a symbolic value, which precises how the value was generated.
@@ -894,7 +896,14 @@ and typed_avalue = { value : avalue; ty : rty }
}]
(** TODO: make those variants of [abs_kind] *)
-type loop_abs_kind = LoopSynthInput | LoopSynthRet [@@deriving show, ord]
+type loop_abs_kind =
+ | LoopSynthInput
+ (** See {!abs_kind.SynthInput} - this abstraction is an input abstraction
+ for a loop body. *)
+ | LoopCall
+ (** An abstraction introduced because we (re-)entered a loop, that we see
+ like a function call. *)
+[@@deriving show, ord]
(** The kind of an abstraction, which keeps track of its origin *)
type abs_kind =