summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2024-03-08 13:41:57 +0100
committerSon Ho2024-03-08 13:41:57 +0100
commitbc154dda94c44b3ae67a3b04d3866cc473aead32 (patch)
tree8d5eb4febc93e2f274a1918ea5353b9746f324d0
parentb604bb9935007a1f0e9c7f556f8196f0e14c85ce (diff)
Remove the option to split fwd/back functions and update SymbolicToPure
-rw-r--r--compiler/Config.ml63
-rw-r--r--compiler/ExtractBase.ml82
-rw-r--r--compiler/ExtractBuiltin.ml131
-rw-r--r--compiler/Main.ml3
-rw-r--r--compiler/SymbolicToPure.ml936
5 files changed, 439 insertions, 776 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index 3b0070c0..6fd866e8 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -92,69 +92,6 @@ let loop_fixed_point_max_num_iters = 2
(** {1 Translation} *)
-(** If true, do not define separate forward/backward functions, but make the
- forward functions return the backward function.
-
- Example:
- {[
- (* Rust *)
- pub fn list_nth<'a, T>(l: &'a mut List<T>, i: u32) -> &'a mut T {
- match l {
- List::Nil => {
- panic!()
- }
- List::Cons(x, tl) => {
- if i == 0 {
- x
- } else {
- list_nth(tl, i - 1)
- }
- }
- }
- }
-
- (* Translation, if return_back_funs = false *)
- def list_nth (T : Type) (l : List T) (i : U32) : Result T :=
- match l with
- | List.Cons x tl =>
- if i = 0#u32
- then Result.ret x
- else do
- let i0 ← i - 1#u32
- list_nth T tl i0
- | List.Nil => Result.fail .panic
-
- def list_nth_back
- (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) :=
- match l with
- | List.Cons x tl =>
- if i = 0#u32
- then Result.ret (List.Cons ret tl)
- else
- do
- let i0 ← i - 1#u32
- let tl0 ← list_nth_back T tl i0 ret
- Result.ret (List.Cons x tl0)
- | List.Nil => Result.fail .panic
-
- (* Translation, if return_back_funs = true *)
- def list_nth (T: Type) (ls : List T) (i : U32) :
- Result (T × (T → Result (List T))) :=
- match ls with
- | List.Cons x tl =>
- if i = 0#u32
- then Result.ret (x, (λ ret => return (ret :: ls)))
- else do
- let i0 ← i - 1#u32
- let (x, back) ← list_nth ls i0
- Return.ret (x,
- (λ ret => do
- let ls ← back ret
- return (x :: ls)))
- ]}
- *)
-let return_back_funs = ref true
-
(** Forbids using field projectors for structures.
If we don't use field projectors, whenever we symbolically expand a structure
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 5aa8323e..04686705 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -471,8 +471,7 @@ type names_map_init = {
assumed_adts : (assumed_ty * string) list;
assumed_structs : (assumed_ty * string) list;
assumed_variants : (assumed_ty * VariantId.id * string) list;
- assumed_llbc_functions :
- (A.assumed_fun_id * RegionGroupId.id option * string) list;
+ assumed_llbc_functions : (A.assumed_fun_id * string) list;
assumed_pure_functions : (pure_assumed_fun_id * string) list;
}
@@ -1052,63 +1051,28 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
(* No Fuel::Succ on purpose *)
]
-let assumed_llbc_functions () :
- (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
- let rg0 = Some T.RegionGroupId.zero in
- let regular : (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
- match !backend with
- | FStar | Coq | HOL4 ->
- [
- (ArrayIndexShared, None, "array_index_usize");
- (ArrayToSliceShared, None, "array_to_slice");
- (ArrayRepeat, None, "array_repeat");
- (SliceIndexShared, None, "slice_index_usize");
- ]
- | Lean ->
- [
- (ArrayIndexShared, None, "Array.index_usize");
- (ArrayToSliceShared, None, "Array.to_slice");
- (ArrayRepeat, None, "Array.repeat");
- (SliceIndexShared, None, "Slice.index_usize");
- ]
- in
- let mut_funs : (A.assumed_fun_id * T.RegionGroupId.id option * string) list =
- if !Config.return_back_funs then
- match !backend with
- | FStar | Coq | HOL4 ->
- [
- (ArrayIndexMut, None, "array_index_mut_usize");
- (ArrayToSliceMut, None, "array_to_slice_mut");
- (SliceIndexMut, None, "slice_index_mut_usize");
- ]
- | Lean ->
- [
- (ArrayIndexMut, None, "Array.index_mut_usize");
- (ArrayToSliceMut, None, "Array.to_slice_mut");
- (SliceIndexMut, None, "Slice.index_mut_usize");
- ]
- else
- match !backend with
- | FStar | Coq | HOL4 ->
- [
- (ArrayIndexMut, None, "array_index_usize");
- (ArrayIndexMut, rg0, "array_update_usize");
- (ArrayToSliceMut, None, "array_to_slice");
- (ArrayToSliceMut, rg0, "array_from_slice");
- (SliceIndexMut, None, "slice_index_usize");
- (SliceIndexMut, rg0, "slice_update_usize");
- ]
- | Lean ->
- [
- (ArrayIndexMut, None, "Array.index_usize");
- (ArrayIndexMut, rg0, "Array.update_usize");
- (ArrayToSliceMut, None, "Array.to_slice");
- (ArrayToSliceMut, rg0, "Array.from_slice");
- (SliceIndexMut, None, "Slice.index_usize");
- (SliceIndexMut, rg0, "Slice.update_usize");
- ]
- in
- regular @ mut_funs
+let assumed_llbc_functions () : (A.assumed_fun_id * string) list =
+ match !backend with
+ | FStar | Coq | HOL4 ->
+ [
+ (ArrayIndexShared, "array_index_usize");
+ (ArrayIndexMut, "array_index_mut_usize");
+ (ArrayToSliceShared, "array_to_slice");
+ (ArrayToSliceMut, "array_to_slice_mut");
+ (ArrayRepeat, "array_repeat");
+ (SliceIndexShared, "slice_index_usize");
+ (SliceIndexMut, "slice_index_mut_usize");
+ ]
+ | Lean ->
+ [
+ (ArrayIndexShared, "Array.index_usize");
+ (ArrayIndexMut, "Array.index_mut_usize");
+ (ArrayToSliceShared, "Array.to_slice");
+ (ArrayToSliceMut, "Array.to_slice_mut");
+ (ArrayRepeat, "Array.repeat");
+ (SliceIndexShared, "Slice.index_usize");
+ (SliceIndexMut, "Slice.index_mut_usize");
+ ]
let assumed_pure_functions () : (pure_assumed_fun_id * string) list =
match !backend with
diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml
index ee8d4831..3ea5655a 100644
--- a/compiler/ExtractBuiltin.ml
+++ b/compiler/ExtractBuiltin.ml
@@ -213,11 +213,7 @@ let mk_builtin_types_map () =
let builtin_types_map = mk_memoized mk_builtin_types_map
-type builtin_fun_info = {
- rg : Types.RegionGroupId.id option;
- extract_name : string;
-}
-[@@deriving show]
+type builtin_fun_info = { extract_name : string } [@@deriving show]
(** The assumed functions.
@@ -227,19 +223,10 @@ type builtin_fun_info = {
*)
let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list
=
- let rg0 = Some Types.RegionGroupId.zero in
(* Small utility *)
let mk_fun (rust_name : string) (extract_name : string option)
- (filter : bool list option) (with_back : bool) (back_no_suffix : bool) :
+ (filter : bool list option) :
pattern * bool list option * builtin_fun_info list =
- (* [back_no_suffix] is used to control whether the backward function should
- have the suffix "_back" or not (if not, then the forward function has the
- prefix "_fwd", and is filtered anyway). This is pertinent only if we split
- the fwd/back functions. *)
- let back_no_suffix = back_no_suffix && not !Config.return_back_funs in
- (* Same for the [with_back] option: this is pertinent only if we split
- the fwd/back functions *)
- let with_back = with_back && not !Config.return_back_funs in
let rust_name =
try parse_pattern rust_name
with Failure _ ->
@@ -251,68 +238,51 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list
| Some name -> split_on_separator name
in
let basename = flatten_name extract_name in
- let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in
- let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in
- let back_suffix = if with_back && back_no_suffix then "" else "_back" in
- let back =
- if with_back then [ { rg = rg0; extract_name = basename ^ back_suffix } ]
- else []
- in
- (rust_name, filter, fwd @ back)
+ let f = [ { extract_name = basename } ] in
+ (rust_name, filter, f)
in
[
- mk_fun "core::mem::replace" None None true false;
+ mk_fun "core::mem::replace" None None;
mk_fun "core::slice::{[@T]}::len"
(Some (backend_choice "slice::len" "Slice::len"))
- None true false;
+ None;
mk_fun "alloc::vec::{alloc::vec::Vec<@T, alloc::alloc::Global>}::new"
- (Some "alloc::vec::Vec::new") None false false;
+ (Some "alloc::vec::Vec::new") None;
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::push" None
- (Some [ true; false ])
- true true;
+ (Some [ true; false ]);
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::insert" None
- (Some [ true; false ])
- true true;
+ (Some [ true; false ]);
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::len" None
- (Some [ true; false ])
- true false;
+ (Some [ true; false ]);
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::index" None
- (Some [ true; true; false ])
- true false;
+ (Some [ true; true; false ]);
mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::index_mut" None
- (Some [ true; true; false ])
- true false;
- mk_fun "alloc::boxed::{Box<@T>}::deref" None
- (Some [ true; false ])
- true false;
- mk_fun "alloc::boxed::{Box<@T>}::deref_mut" None
- (Some [ true; false ])
- true false;
- mk_fun "core::slice::index::{[@T]}::index" None None true false;
- mk_fun "core::slice::index::{[@T]}::index_mut" None None true false;
- mk_fun "core::array::{[@T; @C]}::index" None None true false;
- mk_fun "core::array::{[@T; @C]}::index_mut" None None true false;
+ (Some [ true; true; false ]);
+ mk_fun "alloc::boxed::{Box<@T>}::deref" None (Some [ true; false ]);
+ mk_fun "alloc::boxed::{Box<@T>}::deref_mut" None (Some [ true; false ]);
+ mk_fun "core::slice::index::{[@T]}::index" None None;
+ mk_fun "core::slice::index::{[@T]}::index_mut" None None;
+ mk_fun "core::array::{[@T; @C]}::index" None None;
+ mk_fun "core::array::{[@T; @C]}::index_mut" None None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get"
- (Some "core::slice::index::RangeUsize::get") None true false;
+ (Some "core::slice::index::RangeUsize::get") None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_mut"
- (Some "core::slice::index::RangeUsize::get_mut") None true false;
+ (Some "core::slice::index::RangeUsize::get_mut") None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::index"
- (Some "core::slice::index::RangeUsize::index") None true false;
+ (Some "core::slice::index::RangeUsize::index") None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::index_mut"
- (Some "core::slice::index::RangeUsize::index_mut") None true false;
+ (Some "core::slice::index::RangeUsize::index_mut") None;
mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_unchecked"
- (Some "core::slice::index::RangeUsize::get_unchecked") None false false;
+ (Some "core::slice::index::RangeUsize::get_unchecked") None;
mk_fun
"core::slice::index::{core::ops::range::Range<usize>}::get_unchecked_mut"
- (Some "core::slice::index::RangeUsize::get_unchecked_mut") None false
- false;
- mk_fun "core::slice::index::{usize}::get" None None true false;
- mk_fun "core::slice::index::{usize}::get_mut" None None true false;
- mk_fun "core::slice::index::{usize}::get_unchecked" None None false false;
- mk_fun "core::slice::index::{usize}::get_unchecked_mut" None None false
- false;
- mk_fun "core::slice::index::{usize}::index" None None true false;
- mk_fun "core::slice::index::{usize}::index_mut" None None true false;
+ (Some "core::slice::index::RangeUsize::get_unchecked_mut") None;
+ mk_fun "core::slice::index::{usize}::get" None None;
+ mk_fun "core::slice::index::{usize}::get_mut" None None;
+ mk_fun "core::slice::index::{usize}::get_unchecked" None None;
+ mk_fun "core::slice::index::{usize}::get_unchecked_mut" None None;
+ mk_fun "core::slice::index::{usize}::index" None None;
+ mk_fun "core::slice::index::{usize}::index_mut" None None;
]
let mk_builtin_funs_map () =
@@ -412,10 +382,9 @@ type builtin_trait_decl_info = {
[@@deriving show]
let builtin_trait_decls_info () =
- let rg0 = Some Types.RegionGroupId.zero in
let mk_trait (rust_name : string) ?(extract_name : string option = None)
?(parent_clauses : string list = []) ?(types : string list = [])
- ?(methods : (string * bool) list = []) () : builtin_trait_decl_info =
+ ?(methods : string list = []) () : builtin_trait_decl_info =
let rust_name = parse_pattern rust_name in
let extract_name =
match extract_name with
@@ -443,22 +412,14 @@ let builtin_trait_decls_info () =
List.map mk_type types
in
let methods =
- let mk_method (item_name, with_back) =
+ let mk_method item_name =
(* TODO: factor out with builtin_funs_info *)
let basename =
if !record_fields_short_names then item_name
else extract_name ^ "_" ^ item_name
in
- let back_no_suffix = false in
- let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in
- let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in
- let back_suffix = if with_back && back_no_suffix then "" else "_back" in
- let back =
- if with_back then
- [ { rg = rg0; extract_name = basename ^ back_suffix } ]
- else []
- in
- (item_name, fwd @ back)
+ let fwd = [ { extract_name = basename } ] in
+ (item_name, fwd)
in
List.map mk_method methods
in
@@ -474,21 +435,17 @@ let builtin_trait_decls_info () =
in
[
(* Deref *)
- mk_trait "core::ops::deref::Deref" ~types:[ "Target" ]
- ~methods:[ ("deref", true) ]
+ mk_trait "core::ops::deref::Deref" ~types:[ "Target" ] ~methods:[ "deref" ]
();
(* DerefMut *)
mk_trait "core::ops::deref::DerefMut" ~parent_clauses:[ "derefInst" ]
- ~methods:[ ("deref_mut", true) ]
- ();
+ ~methods:[ "deref_mut" ] ();
(* Index *)
- mk_trait "core::ops::index::Index" ~types:[ "Output" ]
- ~methods:[ ("index", true) ]
+ mk_trait "core::ops::index::Index" ~types:[ "Output" ] ~methods:[ "index" ]
();
(* IndexMut *)
mk_trait "core::ops::index::IndexMut" ~parent_clauses:[ "indexInst" ]
- ~methods:[ ("index_mut", true) ]
- ();
+ ~methods:[ "index_mut" ] ();
(* Sealed *)
mk_trait "core::slice::index::private_slice_index::Sealed" ();
(* SliceIndex *)
@@ -496,12 +453,12 @@ let builtin_trait_decls_info () =
~types:[ "Output" ]
~methods:
[
- ("get", true);
- ("get_mut", true);
- ("get_unchecked", false);
- ("get_unchecked_mut", false);
- ("index", true);
- ("index_mut", true);
+ "get";
+ "get_mut";
+ "get_unchecked";
+ "get_unchecked_mut";
+ "index";
+ "index_mut";
]
();
]
diff --git a/compiler/Main.ml b/compiler/Main.ml
index 4a2d01dc..664ec067 100644
--- a/compiler/Main.ml
+++ b/compiler/Main.ml
@@ -120,9 +120,6 @@ let () =
" Generate a default lakefile.lean (Lean only)" );
("-print-llbc", Arg.Set print_llbc, " Print the imported LLBC");
("-k", Arg.Clear fail_hard, " Do not fail hard in case of error");
- ( "-split-fwd-back",
- Arg.Clear return_back_funs,
- " Split the forward and backward functions." );
( "-tuple-nested-proj",
Arg.Set use_nested_tuple_projectors,
" Use nested projectors for tuples (e.g., (0, 1).snd.fst instead of \
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3a50e495..859d6f17 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -805,11 +805,9 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
that we need to call. This function may be [None] if it has to be ignored
(because it does nothing).
*)
-let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
- (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id)
- (inherited_args : texpression list) (back_args : texpression list)
- (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) :
- bs_ctx * texpression option =
+let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
+ (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
+ : bs_ctx * texpression option =
(* Insert the abstraction in the call informations *)
let info = V.FunCallId.Map.find call_id ctx.calls in
let calls = V.FunCallId.Map.add call_id info ctx.calls in
@@ -819,29 +817,9 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
let abstractions =
V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
in
- (* Compute the expression corresponding to the function *)
- let func =
- if !Config.return_back_funs then
- (* Lookup the variable introduced for the backward function *)
- RegionGroupId.Map.find back_id (Option.get info.back_funs)
- else
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) ->
- let fid = translate_fun_id_or_trait_method_ref ctx fid in
- Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
- in
- let args = List.append inherited_args back_args in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output_ty else output_ty
- in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp fun_id; generics } in
- Some { e = Qualif func; ty = func_ty }
- in
+ (* Compute the expression corresponding to the function.
+ We simply lookup the variable introduced for the backward function. *)
+ let func = RegionGroupId.Map.find back_id (Option.get info.back_funs) in
(* Update the context and return *)
({ ctx with calls; abstractions }, func)
@@ -1124,20 +1102,34 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
let inputs_no_state =
List.map (fun ty -> (Some "ret", ty)) inputs_no_state
in
- (* In case we merge the forward/backward functions:
- we consider the backward function as stateful and potentially failing
+ (* We consider a backward function as stateful and potentially failing
**only if it has inputs** (for the "potentially failing": if it has
not inputs, we directly evaluate it in the body of the forward function).
+
+ For instance, we do the following:
+ {[
+ // Rust
+ fn push<T, 'a>(v : &mut Vec<T>, x : T) { ... }
+
+ (* Generated code: before doing unit elimination.
+ We return (), as well as the backward function; as the backward
+ function doesn't consume any inputs, it is a value that we compute
+ directly in the body of [push].
+ *)
+ let push T (v : Vec T) (x : T) : Result (() * Vec T) = ...
+
+ (* Generated code: after doing unit elimination, if we simplify the merged
+ fwd/back functions (see below). *)
+ let push T (v : Vec T) (x : T) : Result (Vec T) = ...
+ ]}
*)
let back_effect_info =
- if !Config.return_back_funs then
- let b = inputs_no_state <> [] in
- {
- back_effect_info with
- stateful = back_effect_info.stateful && b;
- can_fail = back_effect_info.can_fail && b;
- }
- else back_effect_info
+ let b = inputs_no_state <> [] in
+ {
+ back_effect_info with
+ stateful = back_effect_info.stateful && b;
+ can_fail = back_effect_info.can_fail && b;
+ }
in
let state =
if back_effect_info.stateful then [ (None, mk_state_ty) ] else []
@@ -1145,8 +1137,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
let inputs = inputs_no_state @ state in
let output_names, outputs = compute_back_outputs_for_gid gid in
let filter =
- !Config.simplify_merged_fwd_backs
- && !Config.return_back_funs && inputs = [] && outputs = []
+ !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
in
let info =
{
@@ -1186,7 +1177,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
}
in
let ignore_output =
- if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then
+ if !Config.simplify_merged_fwd_backs then
ty_is_unit fwd_output
&& List.exists
(fun (info : back_sg_info) -> not info.filter)
@@ -1296,10 +1287,10 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig)
(subst : (generic_args * trait_instance_id) option) : ty option list =
List.map (Option.map snd) (compute_back_tys_with_info dsg subst)
-(** In case we merge the fwd/back functions: compute the output type of
- a function, from a decomposed signature. *)
+(** Compute the output type of a function, from a decomposed signature
+ (the output type contains the type of the value returned by the forward
+ function as well as the types of the returned backward functions). *)
let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
- assert !Config.return_back_funs;
(* Compute the arrow types for all the backward functions *)
let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
(* Group the forward output and the types of the backward functions *)
@@ -1315,8 +1306,8 @@ let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
in
mk_output_ty_from_effect_info effect_info output
-let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
- (gid : RegionGroupId.id option) : fun_sig =
+let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) : fun_sig
+ =
let generics = dsg.generics in
let llbc_generics = dsg.llbc_generics in
let preds = dsg.preds in
@@ -1329,27 +1320,10 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid, info.effect_info))
(RegionGroupId.Map.bindings dsg.back_sg))
in
- let mk_output_ty = mk_output_ty_from_effect_info in
let inputs, output =
- (* Two cases depending on whether we split the forward/backward functions or not *)
- if !Config.return_back_funs then (
- assert (gid = None);
- let output = compute_output_ty_from_decomposed dsg in
- let inputs = dsg.fwd_inputs in
- (inputs, output))
- else
- match gid with
- | None ->
- let effect_info = dsg.fwd_info.effect_info in
- let output = mk_output_ty effect_info dsg.fwd_output in
- (dsg.fwd_inputs, output)
- | Some gid ->
- let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
- let effect_info = back_sg.effect_info in
- let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in
- let output = mk_simpl_tuple_ty back_sg.outputs in
- let output = mk_output_ty effect_info output in
- (inputs, output)
+ let output = compute_output_ty_from_decomposed dsg in
+ let inputs = dsg.fwd_inputs in
+ (inputs, output)
in
{ generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info }
@@ -1933,16 +1907,14 @@ and translate_panic (ctx : bs_ctx) : texpression =
*)
match ctx.bid with
| None ->
- if !Config.return_back_funs then
- let back_tys = compute_back_tys ctx.sg None in
- let back_tys = List.filter_map (fun x -> x) back_tys in
- let tys =
- if ctx.sg.fwd_info.ignore_output then back_tys
- else ctx.sg.fwd_output :: back_tys
- in
- let output = mk_simpl_tuple_ty tys in
- mk_output output
- else mk_output ctx.sg.fwd_output
+ let back_tys = compute_back_tys ctx.sg None in
+ let back_tys = List.filter_map (fun x -> x) back_tys in
+ let tys =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty tys in
+ mk_output output
| Some bid ->
let output =
mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs
@@ -2080,107 +2052,103 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
- (* If we do not split the forward/backward functions: generate the
- variables for the backward functions returned by the forward
+ (* Generate the variables for the backward functions returned by the forward
function. *)
let ctx, ignore_fwd_output, back_funs_map, back_funs =
- if !Config.return_back_funs then (
- (* We need to compute the signatures of the backward functions. *)
- let sg = Option.get call.sg in
- let decls_ctx = ctx.decls_ctx in
- let dsg =
- translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
- fid call.regions_hierarchy sg
- (List.map (fun _ -> None) sg.inputs)
- in
- log#ldebug
- (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics));
- let tr_self, all_generics =
- match call.trait_method_generics with
- | None -> (UnknownTrait __FUNCTION__, generics)
- | Some (all_generics, tr_self) ->
- let all_generics =
- ctx_translate_fwd_generic_args ctx all_generics
- in
- let tr_self =
- translate_fwd_trait_instance_id ctx.type_ctx.type_infos
- tr_self
+ (* We need to compute the signatures of the backward functions. *)
+ let sg = Option.get call.sg in
+ let decls_ctx = ctx.decls_ctx in
+ let dsg =
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx fid
+ call.regions_hierarchy sg
+ (List.map (fun _ -> None) sg.inputs)
+ in
+ log#ldebug
+ (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics));
+ let tr_self, all_generics =
+ match call.trait_method_generics with
+ | None -> (UnknownTrait __FUNCTION__, generics)
+ | Some (all_generics, tr_self) ->
+ let all_generics =
+ ctx_translate_fwd_generic_args ctx all_generics
+ in
+ let tr_self =
+ translate_fwd_trait_instance_id ctx.type_ctx.type_infos
+ tr_self
+ in
+ (tr_self, all_generics)
+ in
+ let back_tys =
+ compute_back_tys_with_info dsg (Some (all_generics, tr_self))
+ in
+ (* Introduce variables for the backward functions *)
+ (* Compute a proper basename for the variables *)
+ let back_fun_name =
+ let name =
+ match fid with
+ | FunId (FAssumed fid) -> (
+ match fid with
+ | BoxNew -> "box_new"
+ | BoxFree -> "box_free"
+ | ArrayRepeat -> "array_repeat"
+ | ArrayIndexShared -> "index_shared"
+ | ArrayIndexMut -> "index_mut"
+ | ArrayToSliceShared -> "to_slice_shared"
+ | ArrayToSliceMut -> "to_slice_mut"
+ | SliceIndexShared -> "index_shared"
+ | SliceIndexMut -> "index_mut")
+ | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
+ let decl =
+ FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
in
- (tr_self, all_generics)
+ match Collections.List.last decl.name with
+ | PeIdent (s, _) -> s
+ | PeImpl _ ->
+ (* We shouldn't get there *)
+ raise (Failure "Unexpected"))
in
- let back_tys =
- compute_back_tys_with_info dsg (Some (all_generics, tr_self))
- in
- (* Introduce variables for the backward functions *)
- (* Compute a proper basename for the variables *)
- let back_fun_name =
- let name =
- match fid with
- | FunId (FAssumed fid) -> (
- match fid with
- | BoxNew -> "box_new"
- | BoxFree -> "box_free"
- | ArrayRepeat -> "array_repeat"
- | ArrayIndexShared -> "index_shared"
- | ArrayIndexMut -> "index_mut"
- | ArrayToSliceShared -> "to_slice_shared"
- | ArrayToSliceMut -> "to_slice_mut"
- | SliceIndexShared -> "index_shared"
- | SliceIndexMut -> "index_mut")
- | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
- let decl =
- FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
- in
- match Collections.List.last decl.name with
- | PeIdent (s, _) -> s
- | PeImpl _ ->
- (* We shouldn't get there *)
- raise (Failure "Unexpected"))
- in
- name ^ "_back"
- in
- let ctx, back_vars =
- fresh_opt_vars
- (List.map
- (fun ty ->
- match ty with
- | None -> None
- | Some (back_sg, ty) ->
- (* We insert a name for the variable only if the function
- can fail: if it can fail, it means the call returns a backward
- function. Otherwise, we it directly returns the value given
- back by the backward function, which means we shouldn't
- give it a name like "back..." (it doesn't make sense) *)
- let name =
- if back_sg.effect_info.can_fail then
- Some back_fun_name
- else None
- in
- Some (name, ty))
- back_tys)
- ctx
- in
- let back_funs =
- List.filter_map
- (fun v ->
- match v with
- | None -> None
- | Some v -> Some (mk_typed_pattern_from_var v None))
- back_vars
- in
- let gids =
- List.map
- (fun (g : T.region_var_group) -> g.id)
- call.regions_hierarchy
- in
- let back_vars =
- List.map (Option.map mk_texpression_from_var) back_vars
- in
- let back_funs_map =
- RegionGroupId.Map.of_list (List.combine gids back_vars)
- in
- (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs))
- else (ctx, false, None, [])
+ name ^ "_back"
+ in
+ let ctx, back_vars =
+ fresh_opt_vars
+ (List.map
+ (fun ty ->
+ match ty with
+ | None -> None
+ | Some (back_sg, ty) ->
+ (* We insert a name for the variable only if the function
+ can fail: if it can fail, it means the call returns a backward
+ function. Otherwise, we it directly returns the value given
+ back by the backward function, which means we shouldn't
+ give it a name like "back..." (it doesn't make sense) *)
+ let name =
+ if back_sg.effect_info.can_fail then Some back_fun_name
+ else None
+ in
+ Some (name, ty))
+ back_tys)
+ ctx
+ in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids =
+ List.map
+ (fun (g : T.region_var_group) -> g.id)
+ call.regions_hierarchy
+ in
+ let back_vars =
+ List.map (Option.map mk_texpression_from_var) back_vars
+ in
+ let back_funs_map =
+ RegionGroupId.Map.of_list (List.combine gids back_vars)
+ in
+ (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)
in
(* Compute the pattern for the destination *)
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
@@ -2407,19 +2375,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
raise (Failure "Unreachable")
in
let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in
- let generics = ctx_translate_fwd_generic_args ctx call.generics in
- (* Retrieve the original call and the parent abstractions *)
- let _forward, backwards = get_abs_ancestors ctx abs call_id in
- (* Retrieve the values consumed when we called the forward function and
- * ended the parent backward functions: those give us part of the input
- * values (rem: for now, as we disallow nested lifetimes, there can't be
- * parent backward functions).
- * Note that the forward inputs **include the fuel and the input state**
- * (if we use those). *)
- let fwd_inputs = call_info.forward_inputs in
- let back_ancestors_inputs =
- List.concat (List.map (fun (_abs, args) -> args) backwards)
- in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: those give us the remaining input values *)
let back_inputs = abs_to_consumed ctx ectx abs in
@@ -2434,11 +2389,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
- (* Concatenate all the inpus *)
- let inherited_inputs =
- if !Config.return_back_funs then []
- else List.concat [ fwd_inputs; back_ancestors_inputs ]
- in
let back_inputs = List.append back_inputs back_state in
(* Retrieve the values given back by this function: those are the output
* values. We rely on the fact that there are no nested borrows to use the
@@ -2459,58 +2409,33 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
(* Retrieve the function id, and register the function call in the context
if necessary.Arith_status *)
let ctx, func =
- bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs
- back_inputs generics output.ty ctx
+ bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
- let inputs = List.append inherited_inputs back_inputs in
+ let inputs = back_inputs in
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
List.map
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- (* **Optimization**:
- =================
- We do a small optimization here if we split the forward/backward functions.
- If the backward function doesn't have any output, we don't introduce any function
- call.
- See the comment in {!Config.filter_useless_monadic_calls}.
-
- TODO: use an option to disallow backward functions from updating the state.
- TODO: a backward function which only gives back shared borrows shouldn't
- update the state (state updates should only be used for mutable borrows,
- with objects like Rc for instance).
- *)
- if
- (not !Config.return_back_funs)
- && !Config.filter_useless_monadic_calls
- && outputs = [] && nstate = None
- then (
- (* No outputs - we do a small sanity check: the backward function
- should have exactly the same number of inputs as the forward:
- this number can be different only if the forward function returned
- a value containing mutable borrows, which can't be the case... *)
- assert (List.length inputs = List.length fwd_inputs);
- next_e)
- else
- (* The backward function might also have been filtered if we do not
- split the forward/backward functions *)
- match func with
- | None -> next_e
- | Some func ->
- log#ldebug
- (lazy
- (let args = List.map (texpression_to_string ctx) args in
- "func: "
- ^ texpression_to_string ctx func
- ^ "\nfunc type: "
- ^ pure_ty_to_string ctx func.ty
- ^ "\n\nargs:\n" ^ String.concat "\n" args));
- let call = mk_apps func args in
- mk_let effect_info.can_fail output call next_e
+ (* The backward function might have been filtered it does nothing
+ (consumes unit and returns unit). *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ log#ldebug
+ (lazy
+ (let args = List.map (texpression_to_string ctx) args in
+ "func: "
+ ^ texpression_to_string ctx func
+ ^ "\nfunc type: "
+ ^ pure_ty_to_string ctx func.ty
+ ^ "\n\nargs:\n" ^ String.concat "\n" args));
+ let call = mk_apps func args in
+ mk_let effect_info.can_fail output call next_e
and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -2637,10 +2562,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
else ([], ctx, None)
in
(* Concatenate all the inputs *)
- let inputs =
- if !Config.return_back_funs then List.concat [ back_inputs; back_state ]
- else List.concat [ fwd_inputs; back_inputs; back_state ]
- in
+ let inputs = List.concat [ back_inputs; back_state ] in
(* Retrieve the values given back by this function *)
let ctx, outputs = abs_to_given_back None abs ctx in
(* Group the output values together: first the updated inputs *)
@@ -2670,77 +2592,46 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
- or a call to the variable we introduced for the backward function,
if we merge the forward/backward functions *)
let func =
- if !Config.return_back_funs then
- RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
- else
- let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
- let func = { id = FunOrOp func; generics } in
- Some { e = Qualif func; ty = func_ty }
+ RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
in
- (* **Optimization**:
- =================
- We do a small optimization here in case we split the forward/backward
- functions.
- If the backward function doesn't have any output, we don't introduce
- any function call.
- See the comment in {!Config.filter_useless_monadic_calls}.
-
- TODO: use an option to disallow backward functions from updating the state.
- TODO: a backward function which only gives back shared borrows shouldn't
- update the state (state updates should only be used for mutable borrows,
- with objects like Rc for instance).
- *)
- if
- (not !Config.return_back_funs)
- && !Config.filter_useless_monadic_calls
- && outputs = [] && nstate = None
- then (
- (* No outputs - we do a small sanity check: the backward function
- should have exactly the same number of inputs as the forward:
- this number can be different only if the forward function returned
- a value containing mutable borrows, which can't be the case... *)
- assert (List.length inputs = List.length fwd_inputs);
- next_e)
- else
- (* In case we merge the fwd/back functions we filter the backward
- functions elsewhere *)
- match func with
- | None -> next_e
- | Some func ->
- let call = mk_apps func args in
- (* Add meta-information - this is slightly hacky: we look at the
- values consumed by the abstraction (note that those come from
- *before* we applied the fixed-point context) and use them to
- guide the naming of the output vars.
-
- Also, we need to convert the backward outputs from patterns to
- variables.
-
- Finally, in practice, this works well only for loop bodies:
- we do this only in this case.
- TODO: improve the heuristics, to give weight to the hints for
- instance.
- *)
- let next_e =
- if ctx.inside_loop then
- let consumed_values = abs_to_consumed ctx ectx abs in
- let var_values = List.combine outputs consumed_values in
- let var_values =
- List.filter_map
- (fun (var, v) ->
- match var.Pure.value with
- | PatVar (var, _) -> Some (var, v)
- | _ -> None)
- var_values
- in
- let vars, values = List.split var_values in
- mk_emeta_symbolic_assignments vars values next_e
- else next_e
- in
+ (* We may have filtered the backward function elsewhere if it doesn't
+ do anything (doesn't consume anything and doesn't return anything) *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ let call = mk_apps func args in
+ (* Add meta-information - this is slightly hacky: we look at the
+ values consumed by the abstraction (note that those come from
+ *before* we applied the fixed-point context) and use them to
+ guide the naming of the output vars.
+
+ Also, we need to convert the backward outputs from patterns to
+ variables.
+
+ Finally, in practice, this works well only for loop bodies:
+ we do this only in this case.
+ TODO: improve the heuristics, to give weight to the hints for
+ instance.
+ *)
+ let next_e =
+ if ctx.inside_loop then
+ let consumed_values = abs_to_consumed ctx ectx abs in
+ let var_values = List.combine outputs consumed_values in
+ let var_values =
+ List.filter_map
+ (fun (var, v) ->
+ match var.Pure.value with
+ | PatVar (var, _) -> Some (var, v)
+ | _ -> None)
+ var_values
+ in
+ let vars, values = List.split var_values in
+ mk_emeta_symbolic_assignments vars values next_e
+ else next_e
+ in
- (* Create the let-binding *)
- mk_let effect_info.can_fail output call next_e)
+ (* Create the let-binding *)
+ mk_let effect_info.can_fail output call next_e)
and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -3068,48 +2959,40 @@ and translate_forward_end (ectx : C.eval_ctx)
*)
let ctx =
(* Introduce variables for the inputs and the state variable
- and update the context. *)
- if !Config.return_back_funs then
- (* If the forward/backward functions are not split, we need
- to introduce fresh variables for the additional inputs,
- because they are locally introduced in a lambda *)
- let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
- let ctx, backward_inputs_no_state =
- fresh_vars back_sg.inputs_no_state ctx
- in
- let ctx, backward_inputs_with_state =
- if back_sg.effect_info.stateful then
- let ctx, var, _ = bs_ctx_fresh_state_var ctx in
- (ctx, backward_inputs_no_state @ [ var ])
- else (ctx, backward_inputs_no_state)
- in
- {
- ctx with
- backward_inputs_no_state =
- RegionGroupId.Map.add bid backward_inputs_no_state
- ctx.backward_inputs_no_state;
- backward_inputs_with_state =
- RegionGroupId.Map.add bid backward_inputs_with_state
- ctx.backward_inputs_with_state;
- }
- else
- (* Update the state variable *)
- let back_state_var =
- RegionGroupId.Map.find bid ctx.back_state_vars
- in
- { ctx with state_var = back_state_var }
+ and update the context.
+
+ We need to introduce fresh variables for the additional inputs,
+ because they are locally introduced in a lambda.
+ *)
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ let ctx, backward_inputs_no_state =
+ fresh_vars back_sg.inputs_no_state ctx
+ in
+ let ctx, backward_inputs_with_state =
+ if back_sg.effect_info.stateful then
+ let ctx, var, _ = bs_ctx_fresh_state_var ctx in
+ (ctx, backward_inputs_no_state @ [ var ])
+ else (ctx, backward_inputs_no_state)
+ in
+ {
+ ctx with
+ backward_inputs_no_state =
+ RegionGroupId.Map.add bid backward_inputs_no_state
+ ctx.backward_inputs_no_state;
+ backward_inputs_with_state =
+ RegionGroupId.Map.add bid backward_inputs_with_state
+ ctx.backward_inputs_with_state;
+ }
in
let e = T.RegionGroupId.Map.find bid back_e in
let finish e =
(* Wrap in lambdas if necessary *)
- if !Config.return_back_funs then
- let inputs =
- RegionGroupId.Map.find bid ctx.backward_inputs_with_state
- in
- let places = List.map (fun _ -> None) inputs in
- mk_lambdas_from_vars inputs places e
- else e
+ let inputs =
+ RegionGroupId.Map.find bid ctx.backward_inputs_with_state
+ in
+ let places = List.map (fun _ -> None) inputs in
+ mk_lambdas_from_vars inputs places e
in
(ctx, e, finish)
in
@@ -3131,85 +3014,83 @@ and translate_forward_end (ectx : C.eval_ctx)
function, if needs be, and lookup the proper expression.
*)
let translate_end ctx =
- if !Config.return_back_funs then
- (* Compute the output of the forward function *)
- let fwd_effect_info = ctx.sg.fwd_info.effect_info in
- let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
- let fwd_e = translate_one_end ctx None in
-
- (* Introduce the backward functions. *)
- let back_el =
- List.map
- (fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
- translate_one_end ctx (Some gid))
- (RegionGroupId.Map.bindings ctx.sg.back_sg)
- in
+ (* Compute the output of the forward function *)
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
+ let fwd_e = translate_one_end ctx None in
- (* Compute whether the backward expressions should be evaluated straight
- away or not (i.e., if we should bind them with monadic let-bindings
- or not). We evaluate them straight away if they can fail and have no
- inputs. *)
- let evaluate_backs =
- List.map
- (fun (sg : back_sg_info) ->
- if !Config.simplify_merged_fwd_backs then
- sg.inputs = [] && sg.effect_info.can_fail
- else false)
- (RegionGroupId.Map.values ctx.sg.back_sg)
- in
+ (* Introduce the backward functions. *)
+ let back_el =
+ List.map
+ (fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
+ translate_one_end ctx (Some gid))
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
- (* Introduce variables for the backward functions.
- We lookup the LLBC definition in an attempt to derive pretty names
- for those functions. *)
- let _, back_vars = fresh_back_vars_for_current_fun ctx in
+ (* Compute whether the backward expressions should be evaluated straight
+ away or not (i.e., if we should bind them with monadic let-bindings
+ or not). We evaluate them straight away if they can fail and have no
+ inputs. *)
+ let evaluate_backs =
+ List.map
+ (fun (sg : back_sg_info) ->
+ if !Config.simplify_merged_fwd_backs then
+ sg.inputs = [] && sg.effect_info.can_fail
+ else false)
+ (RegionGroupId.Map.values ctx.sg.back_sg)
+ in
- (* Create the return expressions *)
- let vars =
- let back_vars = List.filter_map (fun x -> x) back_vars in
- if ctx.sg.fwd_info.ignore_output then back_vars
- else pure_fwd_var :: back_vars
- in
- let vars = List.map mk_texpression_from_var vars in
- let ret = mk_simpl_tuple_texpression vars in
-
- (* Introduce a fresh input state variable for the forward expression *)
- let _ctx, state_var, state_pat =
- if fwd_effect_info.stateful then
- let ctx, var, pat = bs_ctx_fresh_state_var ctx in
- (ctx, [ var ], [ pat ])
- else (ctx, [], [])
- in
+ (* Introduce variables for the backward functions.
+ We lookup the LLBC definition in an attempt to derive pretty names
+ for those functions. *)
+ let _, back_vars = fresh_back_vars_for_current_fun ctx in
+
+ (* Create the return expressions *)
+ let vars =
+ let back_vars = List.filter_map (fun x -> x) back_vars in
+ if ctx.sg.fwd_info.ignore_output then back_vars
+ else pure_fwd_var :: back_vars
+ in
+ let vars = List.map mk_texpression_from_var vars in
+ let ret = mk_simpl_tuple_texpression vars in
+
+ (* Introduce a fresh input state variable for the forward expression *)
+ let _ctx, state_var, state_pat =
+ if fwd_effect_info.stateful then
+ let ctx, var, pat = bs_ctx_fresh_state_var ctx in
+ (ctx, [ var ], [ pat ])
+ else (ctx, [], [])
+ in
- let state_var = List.map mk_texpression_from_var state_var in
- let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
- let ret = mk_result_return_texpression ret in
+ let state_var = List.map mk_texpression_from_var state_var in
+ let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
+ let ret = mk_result_return_texpression ret in
- (* Introduce all the let-bindings *)
+ (* Introduce all the let-bindings *)
- (* Combine:
- - the backward variables
- - whether we should evaluate the expression for the backward function
- (i.e., should we use a monadic let-binding or not - we do if the
- backward functions don't have inputs and can fail)
- - the expressions for the backward functions
- *)
- let back_vars_els =
- List.filter_map
- (fun (v, (eval, el)) ->
- match v with None -> None | Some v -> Some (v, eval, el))
- (List.combine back_vars (List.combine evaluate_backs back_el))
- in
- let e =
- List.fold_right
- (fun (var, evaluate, back_e) e ->
- mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
- back_vars_els ret
- in
- (* Bind the expression for the forward output *)
- let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
- let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
- mk_let fwd_effect_info.can_fail pat fwd_e e
- else translate_one_end ctx ctx.bid
+ (* Combine:
+ - the backward variables
+ - whether we should evaluate the expression for the backward function
+ (i.e., should we use a monadic let-binding or not - we do if the
+ backward functions don't have inputs and can fail)
+ - the expressions for the backward functions
+ *)
+ let back_vars_els =
+ List.filter_map
+ (fun (v, (eval, el)) ->
+ match v with None -> None | Some v -> Some (v, eval, el))
+ (List.combine back_vars (List.combine evaluate_backs back_el))
+ in
+ let e =
+ List.fold_right
+ (fun (var, evaluate, back_e) e ->
+ mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
+ back_vars_els ret
+ in
+ (* Bind the expression for the forward output *)
+ let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
+ let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
+ mk_let fwd_effect_info.can_fail pat fwd_e e
in
(* If we are (re-)entering a loop, we need to introduce a call to the
@@ -3279,24 +3160,22 @@ and translate_forward_end (ectx : C.eval_ctx)
backward functions of the outer function.
*)
let ctx, back_funs_map, back_funs =
- if !Config.return_back_funs then
- let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
- let back_funs =
- List.filter_map
- (fun v ->
- match v with
- | None -> None
- | Some v -> Some (mk_typed_pattern_from_var v None))
- back_vars
- in
- let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
- let back_funs_map =
- RegionGroupId.Map.of_list
- (List.combine gids
- (List.map (Option.map mk_texpression_from_var) back_vars))
- in
- (ctx, Some back_funs_map, back_funs)
- else (ctx, None, [])
+ let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
+ let back_funs_map =
+ RegionGroupId.Map.of_list
+ (List.combine gids
+ (List.map (Option.map mk_texpression_from_var) back_vars))
+ in
+ (ctx, Some back_funs_map, back_funs)
in
(* Introduce patterns *)
@@ -3438,91 +3317,58 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* The output type of the loop function *)
let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in
let back_effect_infos, output_ty =
- if !Config.return_back_funs then
- (* The loop backward functions consume the same additional inputs as the parent
- function, but have custom outputs *)
- let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
- let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
- let back_info_tys =
- List.map
- (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
- (* Remark: the effect info of the backward function for the loop
- is almost the same as for the backward function of the parent function.
- Quite importantly, the fact that the function is stateful and/or can fail
- mostly depends on whether it has inputs or not, and the backward functions
- for the loops have the same inputs as the backward functions for the parent
- function.
- *)
- let effect_info = back_sg.effect_info in
- let effect_info = { effect_info with is_rec = true } in
- (* Compute the input/output types *)
- let inputs = List.map snd back_sg.inputs in
- let outputs = given_back in
- (* Filter if necessary *)
- let ty =
- if
- !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
- then None
- else
- let output = mk_simpl_tuple_ty outputs in
- let output =
- mk_back_output_ty_from_effect_info effect_info inputs output
- in
- let ty = mk_arrows inputs output in
- Some ty
- in
- ((id, effect_info), ty))
- (List.combine back_sgs given_back_tys)
- in
- let back_info = List.map fst back_info_tys in
- let back_info = RegionGroupId.Map.of_list back_info in
- let back_tys = List.filter_map snd back_info_tys in
- let output =
- if ctx.sg.fwd_info.ignore_output then back_tys
- else ctx.sg.fwd_output :: back_tys
- in
- let output = mk_simpl_tuple_ty output in
- let effect_info = ctx.sg.fwd_info.effect_info in
- let output =
- if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- let output =
- if effect_info.can_fail && inputs <> [] then mk_result_ty output
- else output
- in
- (back_info, output)
- else
- let back_info =
- RegionGroupId.Map.of_list
- (List.map
- (fun ((id, back_sg) : _ * back_sg_info) ->
- (id, { back_sg.effect_info with is_rec = true }))
- (RegionGroupId.Map.bindings ctx.sg.back_sg))
- in
- let output =
- match ctx.bid with
- | None ->
- (* Forward function: same type as the parent function *)
- (translate_fun_sig_from_decomposed ctx.sg None).output
- | Some rg_id ->
- (* Backward function: custom return type *)
- let doutputs =
- T.RegionGroupId.Map.find rg_id rg_to_given_back_tys
- in
- let output = mk_simpl_tuple_ty doutputs in
- let fwd_effect_info = ctx.sg.fwd_info.effect_info in
- let output =
- if fwd_effect_info.stateful then
- mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- let output =
- if fwd_effect_info.can_fail then mk_result_ty output else output
- in
- output
- in
- (back_info, output)
+ (* The loop backward functions consume the same additional inputs as the parent
+ function, but have custom outputs *)
+ let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
+ let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
+ let back_info_tys =
+ List.map
+ (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
+ (* Remark: the effect info of the backward function for the loop
+ is almost the same as for the backward function of the parent function.
+ Quite importantly, the fact that the function is stateful and/or can fail
+ mostly depends on whether it has inputs or not, and the backward functions
+ for the loops have the same inputs as the backward functions for the parent
+ function.
+ *)
+ let effect_info = back_sg.effect_info in
+ let effect_info = { effect_info with is_rec = true } in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = given_back in
+ (* Filter if necessary *)
+ let ty =
+ if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+ then None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ Some ty
+ in
+ ((id, effect_info), ty))
+ (List.combine back_sgs given_back_tys)
+ in
+ let back_info = List.map fst back_info_tys in
+ let back_info = RegionGroupId.Map.of_list back_info in
+ let back_tys = List.filter_map snd back_info_tys in
+ let output =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty output in
+ let effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output
+ else output
+ in
+ (back_info, output)
in
(* Add the loop information in the context *)
@@ -3708,21 +3554,19 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(* Translate *)
let def = ctx.fun_decl in
- let bid = ctx.bid in
+ assert (ctx.bid = None);
log#ldebug
(lazy
("SymbolicToPure.translate_fun_decl: "
^ name_to_string ctx def.name
- ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string bid
- ^ ")\n"));
+ ^ "\n"));
(* Translate the declaration *)
let def_id = def.def_id in
let llbc_name = def.name in
let name = name_to_string ctx llbc_name in
(* Translate the signature *)
- let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in
+ let signature = translate_fun_sig_from_decomposed ctx.sg in
let regions_hierarchy =
FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies
in
@@ -3732,7 +3576,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx (FunId (FRegular def_id)) None bid
+ get_fun_effect_info ctx (FunId (FRegular def_id)) None None
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
@@ -3760,37 +3604,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
if effect_info.stateful_group then [ mk_state_var ctx.state_var ]
else []
in
- (* Compute the list of (properly ordered) backward input variables *)
- let backward_inputs : var list =
- match bid with
- | None -> []
- | Some back_id ->
- assert (not !Config.return_back_funs);
- let parents_ids =
- list_ordered_ancestor_region_groups regions_hierarchy back_id
- in
- let backward_ids = List.append parents_ids [ back_id ] in
- List.concat
- (List.map
- (fun id ->
- T.RegionGroupId.Map.find id ctx.backward_inputs_no_state)
- backward_ids)
- in
- (* Introduce the backward input state (the state at call site of the
- * *backward* function), if necessary *)
- let back_state =
- if effect_info.stateful && Option.is_some bid then
- let state_var =
- RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars
- in
- [ mk_state_var state_var ]
- else []
- in
(* Group the inputs together *)
- let inputs =
- List.concat
- [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ]
- in
+ let inputs = List.concat [ fuel; ctx.forward_inputs; fwd_state ] in
let inputs_lvs =
List.map (fun v -> mk_typed_pattern_from_var v None) inputs
in
@@ -3799,16 +3614,10 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(lazy
("SymbolicToPure.translate_fun_decl: "
^ name_to_string ctx def.name
- ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string bid
- ^ ")" ^ "\n- forward_inputs: "
+ ^ "\n- inputs: "
^ String.concat ", " (List.map show_var ctx.forward_inputs)
- ^ "\n- fwd_state: "
+ ^ "\n- state: "
^ String.concat ", " (List.map show_var fwd_state)
- ^ "\n- backward_inputs: "
- ^ String.concat ", " (List.map show_var backward_inputs)
- ^ "\n- back_state: "
- ^ String.concat ", " (List.map show_var back_state)
^ "\n- signature.inputs: "
^ String.concat ", "
(List.map (pure_ty_to_string ctx) signature.inputs)));
@@ -3837,7 +3646,6 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
kind = def.kind;
num_loops;
loop_id;
- back_id = bid;
llbc_name;
name;
signature;