diff options
author | Son Ho | 2024-03-08 13:41:57 +0100 |
---|---|---|
committer | Son Ho | 2024-03-08 13:41:57 +0100 |
commit | bc154dda94c44b3ae67a3b04d3866cc473aead32 (patch) | |
tree | 8d5eb4febc93e2f274a1918ea5353b9746f324d0 /compiler | |
parent | b604bb9935007a1f0e9c7f556f8196f0e14c85ce (diff) |
Remove the option to split fwd/back functions and update SymbolicToPure
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/Config.ml | 63 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 82 | ||||
-rw-r--r-- | compiler/ExtractBuiltin.ml | 131 | ||||
-rw-r--r-- | compiler/Main.ml | 3 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 936 |
5 files changed, 439 insertions, 776 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index 3b0070c0..6fd866e8 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -92,69 +92,6 @@ let loop_fixed_point_max_num_iters = 2 (** {1 Translation} *) -(** If true, do not define separate forward/backward functions, but make the - forward functions return the backward function. - - Example: - {[ - (* Rust *) - pub fn list_nth<'a, T>(l: &'a mut List<T>, i: u32) -> &'a mut T { - match l { - List::Nil => { - panic!() - } - List::Cons(x, tl) => { - if i == 0 { - x - } else { - list_nth(tl, i - 1) - } - } - } - } - - (* Translation, if return_back_funs = false *) - def list_nth (T : Type) (l : List T) (i : U32) : Result T := - match l with - | List.Cons x tl => - if i = 0#u32 - then Result.ret x - else do - let i0 ← i - 1#u32 - list_nth T tl i0 - | List.Nil => Result.fail .panic - - def list_nth_back - (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) := - match l with - | List.Cons x tl => - if i = 0#u32 - then Result.ret (List.Cons ret tl) - else - do - let i0 ← i - 1#u32 - let tl0 ← list_nth_back T tl i0 ret - Result.ret (List.Cons x tl0) - | List.Nil => Result.fail .panic - - (* Translation, if return_back_funs = true *) - def list_nth (T: Type) (ls : List T) (i : U32) : - Result (T × (T → Result (List T))) := - match ls with - | List.Cons x tl => - if i = 0#u32 - then Result.ret (x, (λ ret => return (ret :: ls))) - else do - let i0 ← i - 1#u32 - let (x, back) ← list_nth ls i0 - Return.ret (x, - (λ ret => do - let ls ← back ret - return (x :: ls))) - ]} - *) -let return_back_funs = ref true - (** Forbids using field projectors for structures. If we don't use field projectors, whenever we symbolically expand a structure diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 5aa8323e..04686705 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -471,8 +471,7 @@ type names_map_init = { assumed_adts : (assumed_ty * string) list; assumed_structs : (assumed_ty * string) list; assumed_variants : (assumed_ty * VariantId.id * string) list; - assumed_llbc_functions : - (A.assumed_fun_id * RegionGroupId.id option * string) list; + assumed_llbc_functions : (A.assumed_fun_id * string) list; assumed_pure_functions : (pure_assumed_fun_id * string) list; } @@ -1052,63 +1051,28 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (* No Fuel::Succ on purpose *) ] -let assumed_llbc_functions () : - (A.assumed_fun_id * T.RegionGroupId.id option * string) list = - let rg0 = Some T.RegionGroupId.zero in - let regular : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = - match !backend with - | FStar | Coq | HOL4 -> - [ - (ArrayIndexShared, None, "array_index_usize"); - (ArrayToSliceShared, None, "array_to_slice"); - (ArrayRepeat, None, "array_repeat"); - (SliceIndexShared, None, "slice_index_usize"); - ] - | Lean -> - [ - (ArrayIndexShared, None, "Array.index_usize"); - (ArrayToSliceShared, None, "Array.to_slice"); - (ArrayRepeat, None, "Array.repeat"); - (SliceIndexShared, None, "Slice.index_usize"); - ] - in - let mut_funs : (A.assumed_fun_id * T.RegionGroupId.id option * string) list = - if !Config.return_back_funs then - match !backend with - | FStar | Coq | HOL4 -> - [ - (ArrayIndexMut, None, "array_index_mut_usize"); - (ArrayToSliceMut, None, "array_to_slice_mut"); - (SliceIndexMut, None, "slice_index_mut_usize"); - ] - | Lean -> - [ - (ArrayIndexMut, None, "Array.index_mut_usize"); - (ArrayToSliceMut, None, "Array.to_slice_mut"); - (SliceIndexMut, None, "Slice.index_mut_usize"); - ] - else - match !backend with - | FStar | Coq | HOL4 -> - [ - (ArrayIndexMut, None, "array_index_usize"); - (ArrayIndexMut, rg0, "array_update_usize"); - (ArrayToSliceMut, None, "array_to_slice"); - (ArrayToSliceMut, rg0, "array_from_slice"); - (SliceIndexMut, None, "slice_index_usize"); - (SliceIndexMut, rg0, "slice_update_usize"); - ] - | Lean -> - [ - (ArrayIndexMut, None, "Array.index_usize"); - (ArrayIndexMut, rg0, "Array.update_usize"); - (ArrayToSliceMut, None, "Array.to_slice"); - (ArrayToSliceMut, rg0, "Array.from_slice"); - (SliceIndexMut, None, "Slice.index_usize"); - (SliceIndexMut, rg0, "Slice.update_usize"); - ] - in - regular @ mut_funs +let assumed_llbc_functions () : (A.assumed_fun_id * string) list = + match !backend with + | FStar | Coq | HOL4 -> + [ + (ArrayIndexShared, "array_index_usize"); + (ArrayIndexMut, "array_index_mut_usize"); + (ArrayToSliceShared, "array_to_slice"); + (ArrayToSliceMut, "array_to_slice_mut"); + (ArrayRepeat, "array_repeat"); + (SliceIndexShared, "slice_index_usize"); + (SliceIndexMut, "slice_index_mut_usize"); + ] + | Lean -> + [ + (ArrayIndexShared, "Array.index_usize"); + (ArrayIndexMut, "Array.index_mut_usize"); + (ArrayToSliceShared, "Array.to_slice"); + (ArrayToSliceMut, "Array.to_slice_mut"); + (ArrayRepeat, "Array.repeat"); + (SliceIndexShared, "Slice.index_usize"); + (SliceIndexMut, "Slice.index_mut_usize"); + ] let assumed_pure_functions () : (pure_assumed_fun_id * string) list = match !backend with diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index ee8d4831..3ea5655a 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -213,11 +213,7 @@ let mk_builtin_types_map () = let builtin_types_map = mk_memoized mk_builtin_types_map -type builtin_fun_info = { - rg : Types.RegionGroupId.id option; - extract_name : string; -} -[@@deriving show] +type builtin_fun_info = { extract_name : string } [@@deriving show] (** The assumed functions. @@ -227,19 +223,10 @@ type builtin_fun_info = { *) let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list = - let rg0 = Some Types.RegionGroupId.zero in (* Small utility *) let mk_fun (rust_name : string) (extract_name : string option) - (filter : bool list option) (with_back : bool) (back_no_suffix : bool) : + (filter : bool list option) : pattern * bool list option * builtin_fun_info list = - (* [back_no_suffix] is used to control whether the backward function should - have the suffix "_back" or not (if not, then the forward function has the - prefix "_fwd", and is filtered anyway). This is pertinent only if we split - the fwd/back functions. *) - let back_no_suffix = back_no_suffix && not !Config.return_back_funs in - (* Same for the [with_back] option: this is pertinent only if we split - the fwd/back functions *) - let with_back = with_back && not !Config.return_back_funs in let rust_name = try parse_pattern rust_name with Failure _ -> @@ -251,68 +238,51 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info list) list | Some name -> split_on_separator name in let basename = flatten_name extract_name in - let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in - let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in - let back_suffix = if with_back && back_no_suffix then "" else "_back" in - let back = - if with_back then [ { rg = rg0; extract_name = basename ^ back_suffix } ] - else [] - in - (rust_name, filter, fwd @ back) + let f = [ { extract_name = basename } ] in + (rust_name, filter, f) in [ - mk_fun "core::mem::replace" None None true false; + mk_fun "core::mem::replace" None None; mk_fun "core::slice::{[@T]}::len" (Some (backend_choice "slice::len" "Slice::len")) - None true false; + None; mk_fun "alloc::vec::{alloc::vec::Vec<@T, alloc::alloc::Global>}::new" - (Some "alloc::vec::Vec::new") None false false; + (Some "alloc::vec::Vec::new") None; mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::push" None - (Some [ true; false ]) - true true; + (Some [ true; false ]); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::insert" None - (Some [ true; false ]) - true true; + (Some [ true; false ]); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::len" None - (Some [ true; false ]) - true false; + (Some [ true; false ]); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::index" None - (Some [ true; true; false ]) - true false; + (Some [ true; true; false ]); mk_fun "alloc::vec::{alloc::vec::Vec<@T, @A>}::index_mut" None - (Some [ true; true; false ]) - true false; - mk_fun "alloc::boxed::{Box<@T>}::deref" None - (Some [ true; false ]) - true false; - mk_fun "alloc::boxed::{Box<@T>}::deref_mut" None - (Some [ true; false ]) - true false; - mk_fun "core::slice::index::{[@T]}::index" None None true false; - mk_fun "core::slice::index::{[@T]}::index_mut" None None true false; - mk_fun "core::array::{[@T; @C]}::index" None None true false; - mk_fun "core::array::{[@T; @C]}::index_mut" None None true false; + (Some [ true; true; false ]); + mk_fun "alloc::boxed::{Box<@T>}::deref" None (Some [ true; false ]); + mk_fun "alloc::boxed::{Box<@T>}::deref_mut" None (Some [ true; false ]); + mk_fun "core::slice::index::{[@T]}::index" None None; + mk_fun "core::slice::index::{[@T]}::index_mut" None None; + mk_fun "core::array::{[@T; @C]}::index" None None; + mk_fun "core::array::{[@T; @C]}::index_mut" None None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get" - (Some "core::slice::index::RangeUsize::get") None true false; + (Some "core::slice::index::RangeUsize::get") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_mut" - (Some "core::slice::index::RangeUsize::get_mut") None true false; + (Some "core::slice::index::RangeUsize::get_mut") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::index" - (Some "core::slice::index::RangeUsize::index") None true false; + (Some "core::slice::index::RangeUsize::index") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::index_mut" - (Some "core::slice::index::RangeUsize::index_mut") None true false; + (Some "core::slice::index::RangeUsize::index_mut") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_unchecked" - (Some "core::slice::index::RangeUsize::get_unchecked") None false false; + (Some "core::slice::index::RangeUsize::get_unchecked") None; mk_fun "core::slice::index::{core::ops::range::Range<usize>}::get_unchecked_mut" - (Some "core::slice::index::RangeUsize::get_unchecked_mut") None false - false; - mk_fun "core::slice::index::{usize}::get" None None true false; - mk_fun "core::slice::index::{usize}::get_mut" None None true false; - mk_fun "core::slice::index::{usize}::get_unchecked" None None false false; - mk_fun "core::slice::index::{usize}::get_unchecked_mut" None None false - false; - mk_fun "core::slice::index::{usize}::index" None None true false; - mk_fun "core::slice::index::{usize}::index_mut" None None true false; + (Some "core::slice::index::RangeUsize::get_unchecked_mut") None; + mk_fun "core::slice::index::{usize}::get" None None; + mk_fun "core::slice::index::{usize}::get_mut" None None; + mk_fun "core::slice::index::{usize}::get_unchecked" None None; + mk_fun "core::slice::index::{usize}::get_unchecked_mut" None None; + mk_fun "core::slice::index::{usize}::index" None None; + mk_fun "core::slice::index::{usize}::index_mut" None None; ] let mk_builtin_funs_map () = @@ -412,10 +382,9 @@ type builtin_trait_decl_info = { [@@deriving show] let builtin_trait_decls_info () = - let rg0 = Some Types.RegionGroupId.zero in let mk_trait (rust_name : string) ?(extract_name : string option = None) ?(parent_clauses : string list = []) ?(types : string list = []) - ?(methods : (string * bool) list = []) () : builtin_trait_decl_info = + ?(methods : string list = []) () : builtin_trait_decl_info = let rust_name = parse_pattern rust_name in let extract_name = match extract_name with @@ -443,22 +412,14 @@ let builtin_trait_decls_info () = List.map mk_type types in let methods = - let mk_method (item_name, with_back) = + let mk_method item_name = (* TODO: factor out with builtin_funs_info *) let basename = if !record_fields_short_names then item_name else extract_name ^ "_" ^ item_name in - let back_no_suffix = false in - let fwd_suffix = if with_back && back_no_suffix then "_fwd" else "" in - let fwd = [ { rg = None; extract_name = basename ^ fwd_suffix } ] in - let back_suffix = if with_back && back_no_suffix then "" else "_back" in - let back = - if with_back then - [ { rg = rg0; extract_name = basename ^ back_suffix } ] - else [] - in - (item_name, fwd @ back) + let fwd = [ { extract_name = basename } ] in + (item_name, fwd) in List.map mk_method methods in @@ -474,21 +435,17 @@ let builtin_trait_decls_info () = in [ (* Deref *) - mk_trait "core::ops::deref::Deref" ~types:[ "Target" ] - ~methods:[ ("deref", true) ] + mk_trait "core::ops::deref::Deref" ~types:[ "Target" ] ~methods:[ "deref" ] (); (* DerefMut *) mk_trait "core::ops::deref::DerefMut" ~parent_clauses:[ "derefInst" ] - ~methods:[ ("deref_mut", true) ] - (); + ~methods:[ "deref_mut" ] (); (* Index *) - mk_trait "core::ops::index::Index" ~types:[ "Output" ] - ~methods:[ ("index", true) ] + mk_trait "core::ops::index::Index" ~types:[ "Output" ] ~methods:[ "index" ] (); (* IndexMut *) mk_trait "core::ops::index::IndexMut" ~parent_clauses:[ "indexInst" ] - ~methods:[ ("index_mut", true) ] - (); + ~methods:[ "index_mut" ] (); (* Sealed *) mk_trait "core::slice::index::private_slice_index::Sealed" (); (* SliceIndex *) @@ -496,12 +453,12 @@ let builtin_trait_decls_info () = ~types:[ "Output" ] ~methods: [ - ("get", true); - ("get_mut", true); - ("get_unchecked", false); - ("get_unchecked_mut", false); - ("index", true); - ("index_mut", true); + "get"; + "get_mut"; + "get_unchecked"; + "get_unchecked_mut"; + "index"; + "index_mut"; ] (); ] diff --git a/compiler/Main.ml b/compiler/Main.ml index 4a2d01dc..664ec067 100644 --- a/compiler/Main.ml +++ b/compiler/Main.ml @@ -120,9 +120,6 @@ let () = " Generate a default lakefile.lean (Lean only)" ); ("-print-llbc", Arg.Set print_llbc, " Print the imported LLBC"); ("-k", Arg.Clear fail_hard, " Do not fail hard in case of error"); - ( "-split-fwd-back", - Arg.Clear return_back_funs, - " Split the forward and backward functions." ); ( "-tuple-nested-proj", Arg.Set use_nested_tuple_projectors, " Use nested projectors for tuples (e.g., (0, 1).snd.fst instead of \ diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 3a50e495..859d6f17 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -805,11 +805,9 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) that we need to call. This function may be [None] if it has to be ignored (because it does nothing). *) -let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) - (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) - (inherited_args : texpression list) (back_args : texpression list) - (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : - bs_ctx * texpression option = +let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) + (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) + : bs_ctx * texpression option = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -819,29 +817,9 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) let abstractions = V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions in - (* Compute the expression corresponding to the function *) - let func = - if !Config.return_back_funs then - (* Lookup the variable introduced for the backward function *) - RegionGroupId.Map.find back_id (Option.get info.back_funs) - else - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> - let fid = translate_fun_id_or_trait_method_ref ctx fid in - Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") - in - let args = List.append inherited_args back_args in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output_ty else output_ty - in - let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp fun_id; generics } in - Some { e = Qualif func; ty = func_ty } - in + (* Compute the expression corresponding to the function. + We simply lookup the variable introduced for the backward function. *) + let func = RegionGroupId.Map.find back_id (Option.get info.back_funs) in (* Update the context and return *) ({ ctx with calls; abstractions }, func) @@ -1124,20 +1102,34 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs_no_state = List.map (fun ty -> (Some "ret", ty)) inputs_no_state in - (* In case we merge the forward/backward functions: - we consider the backward function as stateful and potentially failing + (* We consider a backward function as stateful and potentially failing **only if it has inputs** (for the "potentially failing": if it has not inputs, we directly evaluate it in the body of the forward function). + + For instance, we do the following: + {[ + // Rust + fn push<T, 'a>(v : &mut Vec<T>, x : T) { ... } + + (* Generated code: before doing unit elimination. + We return (), as well as the backward function; as the backward + function doesn't consume any inputs, it is a value that we compute + directly in the body of [push]. + *) + let push T (v : Vec T) (x : T) : Result (() * Vec T) = ... + + (* Generated code: after doing unit elimination, if we simplify the merged + fwd/back functions (see below). *) + let push T (v : Vec T) (x : T) : Result (Vec T) = ... + ]} *) let back_effect_info = - if !Config.return_back_funs then - let b = inputs_no_state <> [] in - { - back_effect_info with - stateful = back_effect_info.stateful && b; - can_fail = back_effect_info.can_fail && b; - } - else back_effect_info + let b = inputs_no_state <> [] in + { + back_effect_info with + stateful = back_effect_info.stateful && b; + can_fail = back_effect_info.can_fail && b; + } in let state = if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] @@ -1145,8 +1137,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in let filter = - !Config.simplify_merged_fwd_backs - && !Config.return_back_funs && inputs = [] && outputs = [] + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] in let info = { @@ -1186,7 +1177,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed } in let ignore_output = - if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then + if !Config.simplify_merged_fwd_backs then ty_is_unit fwd_output && List.exists (fun (info : back_sg_info) -> not info.filter) @@ -1296,10 +1287,10 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) (subst : (generic_args * trait_instance_id) option) : ty option list = List.map (Option.map snd) (compute_back_tys_with_info dsg subst) -(** In case we merge the fwd/back functions: compute the output type of - a function, from a decomposed signature. *) +(** Compute the output type of a function, from a decomposed signature + (the output type contains the type of the value returned by the forward + function as well as the types of the returned backward functions). *) let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = - assert !Config.return_back_funs; (* Compute the arrow types for all the backward functions *) let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in (* Group the forward output and the types of the backward functions *) @@ -1315,8 +1306,8 @@ let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = in mk_output_ty_from_effect_info effect_info output -let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) - (gid : RegionGroupId.id option) : fun_sig = +let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) : fun_sig + = let generics = dsg.generics in let llbc_generics = dsg.llbc_generics in let preds = dsg.preds in @@ -1329,27 +1320,10 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid, info.effect_info)) (RegionGroupId.Map.bindings dsg.back_sg)) in - let mk_output_ty = mk_output_ty_from_effect_info in let inputs, output = - (* Two cases depending on whether we split the forward/backward functions or not *) - if !Config.return_back_funs then ( - assert (gid = None); - let output = compute_output_ty_from_decomposed dsg in - let inputs = dsg.fwd_inputs in - (inputs, output)) - else - match gid with - | None -> - let effect_info = dsg.fwd_info.effect_info in - let output = mk_output_ty effect_info dsg.fwd_output in - (dsg.fwd_inputs, output) - | Some gid -> - let back_sg = RegionGroupId.Map.find gid dsg.back_sg in - let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty effect_info output in - (inputs, output) + let output = compute_output_ty_from_decomposed dsg in + let inputs = dsg.fwd_inputs in + (inputs, output) in { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } @@ -1933,16 +1907,14 @@ and translate_panic (ctx : bs_ctx) : texpression = *) match ctx.bid with | None -> - if !Config.return_back_funs then - let back_tys = compute_back_tys ctx.sg None in - let back_tys = List.filter_map (fun x -> x) back_tys in - let tys = - if ctx.sg.fwd_info.ignore_output then back_tys - else ctx.sg.fwd_output :: back_tys - in - let output = mk_simpl_tuple_ty tys in - mk_output output - else mk_output ctx.sg.fwd_output + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in + mk_output output | Some bid -> let output = mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs @@ -2080,107 +2052,103 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in - (* If we do not split the forward/backward functions: generate the - variables for the backward functions returned by the forward + (* Generate the variables for the backward functions returned by the forward function. *) let ctx, ignore_fwd_output, back_funs_map, back_funs = - if !Config.return_back_funs then ( - (* We need to compute the signatures of the backward functions. *) - let sg = Option.get call.sg in - let decls_ctx = ctx.decls_ctx in - let dsg = - translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx - fid call.regions_hierarchy sg - (List.map (fun _ -> None) sg.inputs) - in - log#ldebug - (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); - let tr_self, all_generics = - match call.trait_method_generics with - | None -> (UnknownTrait __FUNCTION__, generics) - | Some (all_generics, tr_self) -> - let all_generics = - ctx_translate_fwd_generic_args ctx all_generics - in - let tr_self = - translate_fwd_trait_instance_id ctx.type_ctx.type_infos - tr_self + (* We need to compute the signatures of the backward functions. *) + let sg = Option.get call.sg in + let decls_ctx = ctx.decls_ctx in + let dsg = + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx fid + call.regions_hierarchy sg + (List.map (fun _ -> None) sg.inputs) + in + log#ldebug + (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); + let tr_self, all_generics = + match call.trait_method_generics with + | None -> (UnknownTrait __FUNCTION__, generics) + | Some (all_generics, tr_self) -> + let all_generics = + ctx_translate_fwd_generic_args ctx all_generics + in + let tr_self = + translate_fwd_trait_instance_id ctx.type_ctx.type_infos + tr_self + in + (tr_self, all_generics) + in + let back_tys = + compute_back_tys_with_info dsg (Some (all_generics, tr_self)) + in + (* Introduce variables for the backward functions *) + (* Compute a proper basename for the variables *) + let back_fun_name = + let name = + match fid with + | FunId (FAssumed fid) -> ( + match fid with + | BoxNew -> "box_new" + | BoxFree -> "box_free" + | ArrayRepeat -> "array_repeat" + | ArrayIndexShared -> "index_shared" + | ArrayIndexMut -> "index_mut" + | ArrayToSliceShared -> "to_slice_shared" + | ArrayToSliceMut -> "to_slice_mut" + | SliceIndexShared -> "index_shared" + | SliceIndexMut -> "index_mut") + | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( + let decl = + FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls in - (tr_self, all_generics) + match Collections.List.last decl.name with + | PeIdent (s, _) -> s + | PeImpl _ -> + (* We shouldn't get there *) + raise (Failure "Unexpected")) in - let back_tys = - compute_back_tys_with_info dsg (Some (all_generics, tr_self)) - in - (* Introduce variables for the backward functions *) - (* Compute a proper basename for the variables *) - let back_fun_name = - let name = - match fid with - | FunId (FAssumed fid) -> ( - match fid with - | BoxNew -> "box_new" - | BoxFree -> "box_free" - | ArrayRepeat -> "array_repeat" - | ArrayIndexShared -> "index_shared" - | ArrayIndexMut -> "index_mut" - | ArrayToSliceShared -> "to_slice_shared" - | ArrayToSliceMut -> "to_slice_mut" - | SliceIndexShared -> "index_shared" - | SliceIndexMut -> "index_mut") - | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( - let decl = - FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls - in - match Collections.List.last decl.name with - | PeIdent (s, _) -> s - | PeImpl _ -> - (* We shouldn't get there *) - raise (Failure "Unexpected")) - in - name ^ "_back" - in - let ctx, back_vars = - fresh_opt_vars - (List.map - (fun ty -> - match ty with - | None -> None - | Some (back_sg, ty) -> - (* We insert a name for the variable only if the function - can fail: if it can fail, it means the call returns a backward - function. Otherwise, we it directly returns the value given - back by the backward function, which means we shouldn't - give it a name like "back..." (it doesn't make sense) *) - let name = - if back_sg.effect_info.can_fail then - Some back_fun_name - else None - in - Some (name, ty)) - back_tys) - ctx - in - let back_funs = - List.filter_map - (fun v -> - match v with - | None -> None - | Some v -> Some (mk_typed_pattern_from_var v None)) - back_vars - in - let gids = - List.map - (fun (g : T.region_var_group) -> g.id) - call.regions_hierarchy - in - let back_vars = - List.map (Option.map mk_texpression_from_var) back_vars - in - let back_funs_map = - RegionGroupId.Map.of_list (List.combine gids back_vars) - in - (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)) - else (ctx, false, None, []) + name ^ "_back" + in + let ctx, back_vars = + fresh_opt_vars + (List.map + (fun ty -> + match ty with + | None -> None + | Some (back_sg, ty) -> + (* We insert a name for the variable only if the function + can fail: if it can fail, it means the call returns a backward + function. Otherwise, we it directly returns the value given + back by the backward function, which means we shouldn't + give it a name like "back..." (it doesn't make sense) *) + let name = + if back_sg.effect_info.can_fail then Some back_fun_name + else None + in + Some (name, ty)) + back_tys) + ctx + in + let back_funs = + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars + in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in + let back_vars = + List.map (Option.map mk_texpression_from_var) back_vars + in + let back_funs_map = + RegionGroupId.Map.of_list (List.combine gids back_vars) + in + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs) in (* Compute the pattern for the destination *) let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in @@ -2407,19 +2375,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) raise (Failure "Unreachable") in let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in - let generics = ctx_translate_fwd_generic_args ctx call.generics in - (* Retrieve the original call and the parent abstractions *) - let _forward, backwards = get_abs_ancestors ctx abs call_id in - (* Retrieve the values consumed when we called the forward function and - * ended the parent backward functions: those give us part of the input - * values (rem: for now, as we disallow nested lifetimes, there can't be - * parent backward functions). - * Note that the forward inputs **include the fuel and the input state** - * (if we use those). *) - let fwd_inputs = call_info.forward_inputs in - let back_ancestors_inputs = - List.concat (List.map (fun (_abs, args) -> args) backwards) - in (* Retrieve the values consumed upon ending the loans inside this * abstraction: those give us the remaining input values *) let back_inputs = abs_to_consumed ctx ectx abs in @@ -2434,11 +2389,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) ([ back_state ], ctx, Some nstate) else ([], ctx, None) in - (* Concatenate all the inpus *) - let inherited_inputs = - if !Config.return_back_funs then [] - else List.concat [ fwd_inputs; back_ancestors_inputs ] - in let back_inputs = List.append back_inputs back_state in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the @@ -2459,58 +2409,33 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (* Retrieve the function id, and register the function call in the context if necessary.Arith_status *) let ctx, func = - bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs - back_inputs generics output.ty ctx + bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx in (* Translate the next expression *) let next_e = translate_expression e ctx in (* Put everything together *) - let inputs = List.append inherited_inputs back_inputs in + let inputs = back_inputs in let args_mplaces = List.map (fun _ -> None) inputs in let args = List.map (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - (* **Optimization**: - ================= - We do a small optimization here if we split the forward/backward functions. - If the backward function doesn't have any output, we don't introduce any function - call. - See the comment in {!Config.filter_useless_monadic_calls}. - - TODO: use an option to disallow backward functions from updating the state. - TODO: a backward function which only gives back shared borrows shouldn't - update the state (state updates should only be used for mutable borrows, - with objects like Rc for instance). - *) - if - (not !Config.return_back_funs) - && !Config.filter_useless_monadic_calls - && outputs = [] && nstate = None - then ( - (* No outputs - we do a small sanity check: the backward function - should have exactly the same number of inputs as the forward: - this number can be different only if the forward function returned - a value containing mutable borrows, which can't be the case... *) - assert (List.length inputs = List.length fwd_inputs); - next_e) - else - (* The backward function might also have been filtered if we do not - split the forward/backward functions *) - match func with - | None -> next_e - | Some func -> - log#ldebug - (lazy - (let args = List.map (texpression_to_string ctx) args in - "func: " - ^ texpression_to_string ctx func - ^ "\nfunc type: " - ^ pure_ty_to_string ctx func.ty - ^ "\n\nargs:\n" ^ String.concat "\n" args)); - let call = mk_apps func args in - mk_let effect_info.can_fail output call next_e + (* The backward function might have been filtered it does nothing + (consumes unit and returns unit). *) + match func with + | None -> next_e + | Some func -> + log#ldebug + (lazy + (let args = List.map (texpression_to_string ctx) args in + "func: " + ^ texpression_to_string ctx func + ^ "\nfunc type: " + ^ pure_ty_to_string ctx func.ty + ^ "\n\nargs:\n" ^ String.concat "\n" args)); + let call = mk_apps func args in + mk_let effect_info.can_fail output call next_e and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2637,10 +2562,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inputs *) - let inputs = - if !Config.return_back_funs then List.concat [ back_inputs; back_state ] - else List.concat [ fwd_inputs; back_inputs; back_state ] - in + let inputs = List.concat [ back_inputs; back_state ] in (* Retrieve the values given back by this function *) let ctx, outputs = abs_to_given_back None abs ctx in (* Group the output values together: first the updated inputs *) @@ -2670,77 +2592,46 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) - or a call to the variable we introduced for the backward function, if we merge the forward/backward functions *) let func = - if !Config.return_back_funs then - RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) - else - let func_ty = mk_arrows input_tys ret_ty in - let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; generics } in - Some { e = Qualif func; ty = func_ty } + RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) in - (* **Optimization**: - ================= - We do a small optimization here in case we split the forward/backward - functions. - If the backward function doesn't have any output, we don't introduce - any function call. - See the comment in {!Config.filter_useless_monadic_calls}. - - TODO: use an option to disallow backward functions from updating the state. - TODO: a backward function which only gives back shared borrows shouldn't - update the state (state updates should only be used for mutable borrows, - with objects like Rc for instance). - *) - if - (not !Config.return_back_funs) - && !Config.filter_useless_monadic_calls - && outputs = [] && nstate = None - then ( - (* No outputs - we do a small sanity check: the backward function - should have exactly the same number of inputs as the forward: - this number can be different only if the forward function returned - a value containing mutable borrows, which can't be the case... *) - assert (List.length inputs = List.length fwd_inputs); - next_e) - else - (* In case we merge the fwd/back functions we filter the backward - functions elsewhere *) - match func with - | None -> next_e - | Some func -> - let call = mk_apps func args in - (* Add meta-information - this is slightly hacky: we look at the - values consumed by the abstraction (note that those come from - *before* we applied the fixed-point context) and use them to - guide the naming of the output vars. - - Also, we need to convert the backward outputs from patterns to - variables. - - Finally, in practice, this works well only for loop bodies: - we do this only in this case. - TODO: improve the heuristics, to give weight to the hints for - instance. - *) - let next_e = - if ctx.inside_loop then - let consumed_values = abs_to_consumed ctx ectx abs in - let var_values = List.combine outputs consumed_values in - let var_values = - List.filter_map - (fun (var, v) -> - match var.Pure.value with - | PatVar (var, _) -> Some (var, v) - | _ -> None) - var_values - in - let vars, values = List.split var_values in - mk_emeta_symbolic_assignments vars values next_e - else next_e - in + (* We may have filtered the backward function elsewhere if it doesn't + do anything (doesn't consume anything and doesn't return anything) *) + match func with + | None -> next_e + | Some func -> + let call = mk_apps func args in + (* Add meta-information - this is slightly hacky: we look at the + values consumed by the abstraction (note that those come from + *before* we applied the fixed-point context) and use them to + guide the naming of the output vars. + + Also, we need to convert the backward outputs from patterns to + variables. + + Finally, in practice, this works well only for loop bodies: + we do this only in this case. + TODO: improve the heuristics, to give weight to the hints for + instance. + *) + let next_e = + if ctx.inside_loop then + let consumed_values = abs_to_consumed ctx ectx abs in + let var_values = List.combine outputs consumed_values in + let var_values = + List.filter_map + (fun (var, v) -> + match var.Pure.value with + | PatVar (var, _) -> Some (var, v) + | _ -> None) + var_values + in + let vars, values = List.split var_values in + mk_emeta_symbolic_assignments vars values next_e + else next_e + in - (* Create the let-binding *) - mk_let effect_info.can_fail output call next_e) + (* Create the let-binding *) + mk_let effect_info.can_fail output call next_e) and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -3068,48 +2959,40 @@ and translate_forward_end (ectx : C.eval_ctx) *) let ctx = (* Introduce variables for the inputs and the state variable - and update the context. *) - if !Config.return_back_funs then - (* If the forward/backward functions are not split, we need - to introduce fresh variables for the additional inputs, - because they are locally introduced in a lambda *) - let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in - let ctx, backward_inputs_no_state = - fresh_vars back_sg.inputs_no_state ctx - in - let ctx, backward_inputs_with_state = - if back_sg.effect_info.stateful then - let ctx, var, _ = bs_ctx_fresh_state_var ctx in - (ctx, backward_inputs_no_state @ [ var ]) - else (ctx, backward_inputs_no_state) - in - { - ctx with - backward_inputs_no_state = - RegionGroupId.Map.add bid backward_inputs_no_state - ctx.backward_inputs_no_state; - backward_inputs_with_state = - RegionGroupId.Map.add bid backward_inputs_with_state - ctx.backward_inputs_with_state; - } - else - (* Update the state variable *) - let back_state_var = - RegionGroupId.Map.find bid ctx.back_state_vars - in - { ctx with state_var = back_state_var } + and update the context. + + We need to introduce fresh variables for the additional inputs, + because they are locally introduced in a lambda. + *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let ctx, backward_inputs_no_state = + fresh_vars back_sg.inputs_no_state ctx + in + let ctx, backward_inputs_with_state = + if back_sg.effect_info.stateful then + let ctx, var, _ = bs_ctx_fresh_state_var ctx in + (ctx, backward_inputs_no_state @ [ var ]) + else (ctx, backward_inputs_no_state) + in + { + ctx with + backward_inputs_no_state = + RegionGroupId.Map.add bid backward_inputs_no_state + ctx.backward_inputs_no_state; + backward_inputs_with_state = + RegionGroupId.Map.add bid backward_inputs_with_state + ctx.backward_inputs_with_state; + } in let e = T.RegionGroupId.Map.find bid back_e in let finish e = (* Wrap in lambdas if necessary *) - if !Config.return_back_funs then - let inputs = - RegionGroupId.Map.find bid ctx.backward_inputs_with_state - in - let places = List.map (fun _ -> None) inputs in - mk_lambdas_from_vars inputs places e - else e + let inputs = + RegionGroupId.Map.find bid ctx.backward_inputs_with_state + in + let places = List.map (fun _ -> None) inputs in + mk_lambdas_from_vars inputs places e in (ctx, e, finish) in @@ -3131,85 +3014,83 @@ and translate_forward_end (ectx : C.eval_ctx) function, if needs be, and lookup the proper expression. *) let translate_end ctx = - if !Config.return_back_funs then - (* Compute the output of the forward function *) - let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in - let fwd_e = translate_one_end ctx None in - - (* Introduce the backward functions. *) - let back_el = - List.map - (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> - translate_one_end ctx (Some gid)) - (RegionGroupId.Map.bindings ctx.sg.back_sg) - in + (* Compute the output of the forward function *) + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in + let fwd_e = translate_one_end ctx None in - (* Compute whether the backward expressions should be evaluated straight - away or not (i.e., if we should bind them with monadic let-bindings - or not). We evaluate them straight away if they can fail and have no - inputs. *) - let evaluate_backs = - List.map - (fun (sg : back_sg_info) -> - if !Config.simplify_merged_fwd_backs then - sg.inputs = [] && sg.effect_info.can_fail - else false) - (RegionGroupId.Map.values ctx.sg.back_sg) - in + (* Introduce the backward functions. *) + let back_el = + List.map + (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> + translate_one_end ctx (Some gid)) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in - (* Introduce variables for the backward functions. - We lookup the LLBC definition in an attempt to derive pretty names - for those functions. *) - let _, back_vars = fresh_back_vars_for_current_fun ctx in + (* Compute whether the backward expressions should be evaluated straight + away or not (i.e., if we should bind them with monadic let-bindings + or not). We evaluate them straight away if they can fail and have no + inputs. *) + let evaluate_backs = + List.map + (fun (sg : back_sg_info) -> + if !Config.simplify_merged_fwd_backs then + sg.inputs = [] && sg.effect_info.can_fail + else false) + (RegionGroupId.Map.values ctx.sg.back_sg) + in - (* Create the return expressions *) - let vars = - let back_vars = List.filter_map (fun x -> x) back_vars in - if ctx.sg.fwd_info.ignore_output then back_vars - else pure_fwd_var :: back_vars - in - let vars = List.map mk_texpression_from_var vars in - let ret = mk_simpl_tuple_texpression vars in - - (* Introduce a fresh input state variable for the forward expression *) - let _ctx, state_var, state_pat = - if fwd_effect_info.stateful then - let ctx, var, pat = bs_ctx_fresh_state_var ctx in - (ctx, [ var ], [ pat ]) - else (ctx, [], []) - in + (* Introduce variables for the backward functions. + We lookup the LLBC definition in an attempt to derive pretty names + for those functions. *) + let _, back_vars = fresh_back_vars_for_current_fun ctx in + + (* Create the return expressions *) + let vars = + let back_vars = List.filter_map (fun x -> x) back_vars in + if ctx.sg.fwd_info.ignore_output then back_vars + else pure_fwd_var :: back_vars + in + let vars = List.map mk_texpression_from_var vars in + let ret = mk_simpl_tuple_texpression vars in + + (* Introduce a fresh input state variable for the forward expression *) + let _ctx, state_var, state_pat = + if fwd_effect_info.stateful then + let ctx, var, pat = bs_ctx_fresh_state_var ctx in + (ctx, [ var ], [ pat ]) + else (ctx, [], []) + in - let state_var = List.map mk_texpression_from_var state_var in - let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in - let ret = mk_result_return_texpression ret in + let state_var = List.map mk_texpression_from_var state_var in + let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in + let ret = mk_result_return_texpression ret in - (* Introduce all the let-bindings *) + (* Introduce all the let-bindings *) - (* Combine: - - the backward variables - - whether we should evaluate the expression for the backward function - (i.e., should we use a monadic let-binding or not - we do if the - backward functions don't have inputs and can fail) - - the expressions for the backward functions - *) - let back_vars_els = - List.filter_map - (fun (v, (eval, el)) -> - match v with None -> None | Some v -> Some (v, eval, el)) - (List.combine back_vars (List.combine evaluate_backs back_el)) - in - let e = - List.fold_right - (fun (var, evaluate, back_e) e -> - mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) - back_vars_els ret - in - (* Bind the expression for the forward output *) - let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in - let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in - mk_let fwd_effect_info.can_fail pat fwd_e e - else translate_one_end ctx ctx.bid + (* Combine: + - the backward variables + - whether we should evaluate the expression for the backward function + (i.e., should we use a monadic let-binding or not - we do if the + backward functions don't have inputs and can fail) + - the expressions for the backward functions + *) + let back_vars_els = + List.filter_map + (fun (v, (eval, el)) -> + match v with None -> None | Some v -> Some (v, eval, el)) + (List.combine back_vars (List.combine evaluate_backs back_el)) + in + let e = + List.fold_right + (fun (var, evaluate, back_e) e -> + mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) + back_vars_els ret + in + (* Bind the expression for the forward output *) + let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in + let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in + mk_let fwd_effect_info.can_fail pat fwd_e e in (* If we are (re-)entering a loop, we need to introduce a call to the @@ -3279,24 +3160,22 @@ and translate_forward_end (ectx : C.eval_ctx) backward functions of the outer function. *) let ctx, back_funs_map, back_funs = - if !Config.return_back_funs then - let ctx, back_vars = fresh_back_vars_for_current_fun ctx in - let back_funs = - List.filter_map - (fun v -> - match v with - | None -> None - | Some v -> Some (mk_typed_pattern_from_var v None)) - back_vars - in - let gids = RegionGroupId.Map.keys ctx.sg.back_sg in - let back_funs_map = - RegionGroupId.Map.of_list - (List.combine gids - (List.map (Option.map mk_texpression_from_var) back_vars)) - in - (ctx, Some back_funs_map, back_funs) - else (ctx, None, []) + let ctx, back_vars = fresh_back_vars_for_current_fun ctx in + let back_funs = + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars + in + let gids = RegionGroupId.Map.keys ctx.sg.back_sg in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids + (List.map (Option.map mk_texpression_from_var) back_vars)) + in + (ctx, Some back_funs_map, back_funs) in (* Introduce patterns *) @@ -3438,91 +3317,58 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* The output type of the loop function *) let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in let back_effect_infos, output_ty = - if !Config.return_back_funs then - (* The loop backward functions consume the same additional inputs as the parent - function, but have custom outputs *) - let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in - let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in - let back_info_tys = - List.map - (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> - (* Remark: the effect info of the backward function for the loop - is almost the same as for the backward function of the parent function. - Quite importantly, the fact that the function is stateful and/or can fail - mostly depends on whether it has inputs or not, and the backward functions - for the loops have the same inputs as the backward functions for the parent - function. - *) - let effect_info = back_sg.effect_info in - let effect_info = { effect_info with is_rec = true } in - (* Compute the input/output types *) - let inputs = List.map snd back_sg.inputs in - let outputs = given_back in - (* Filter if necessary *) - let ty = - if - !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] - then None - else - let output = mk_simpl_tuple_ty outputs in - let output = - mk_back_output_ty_from_effect_info effect_info inputs output - in - let ty = mk_arrows inputs output in - Some ty - in - ((id, effect_info), ty)) - (List.combine back_sgs given_back_tys) - in - let back_info = List.map fst back_info_tys in - let back_info = RegionGroupId.Map.of_list back_info in - let back_tys = List.filter_map snd back_info_tys in - let output = - if ctx.sg.fwd_info.ignore_output then back_tys - else ctx.sg.fwd_output :: back_tys - in - let output = mk_simpl_tuple_ty output in - let effect_info = ctx.sg.fwd_info.effect_info in - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if effect_info.can_fail && inputs <> [] then mk_result_ty output - else output - in - (back_info, output) - else - let back_info = - RegionGroupId.Map.of_list - (List.map - (fun ((id, back_sg) : _ * back_sg_info) -> - (id, { back_sg.effect_info with is_rec = true })) - (RegionGroupId.Map.bindings ctx.sg.back_sg)) - in - let output = - match ctx.bid with - | None -> - (* Forward function: same type as the parent function *) - (translate_fun_sig_from_decomposed ctx.sg None).output - | Some rg_id -> - (* Backward function: custom return type *) - let doutputs = - T.RegionGroupId.Map.find rg_id rg_to_given_back_tys - in - let output = mk_simpl_tuple_ty doutputs in - let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let output = - if fwd_effect_info.stateful then - mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if fwd_effect_info.can_fail then mk_result_ty output else output - in - output - in - (back_info, output) + (* The loop backward functions consume the same additional inputs as the parent + function, but have custom outputs *) + let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in + let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in + let back_info_tys = + List.map + (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> + (* Remark: the effect info of the backward function for the loop + is almost the same as for the backward function of the parent function. + Quite importantly, the fact that the function is stateful and/or can fail + mostly depends on whether it has inputs or not, and the backward functions + for the loops have the same inputs as the backward functions for the parent + function. + *) + let effect_info = back_sg.effect_info in + let effect_info = { effect_info with is_rec = true } in + (* Compute the input/output types *) + let inputs = List.map snd back_sg.inputs in + let outputs = given_back in + (* Filter if necessary *) + let ty = + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty + in + ((id, effect_info), ty)) + (List.combine back_sgs given_back_tys) + in + let back_info = List.map fst back_info_tys in + let back_info = RegionGroupId.Map.of_list back_info in + let back_tys = List.filter_map snd back_info_tys in + let output = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty output in + let effect_info = ctx.sg.fwd_info.effect_info in + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + in + (back_info, output) in (* Add the loop information in the context *) @@ -3708,21 +3554,19 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Translate *) let def = ctx.fun_decl in - let bid = ctx.bid in + assert (ctx.bid = None); log#ldebug (lazy ("SymbolicToPure.translate_fun_decl: " ^ name_to_string ctx def.name - ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string bid - ^ ")\n")); + ^ "\n")); (* Translate the declaration *) let def_id = def.def_id in let llbc_name = def.name in let name = name_to_string ctx llbc_name in (* Translate the signature *) - let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in + let signature = translate_fun_sig_from_decomposed ctx.sg in let regions_hierarchy = FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies in @@ -3732,7 +3576,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx (FunId (FRegular def_id)) None bid + get_fun_effect_info ctx (FunId (FRegular def_id)) None None in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) @@ -3760,37 +3604,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = if effect_info.stateful_group then [ mk_state_var ctx.state_var ] else [] in - (* Compute the list of (properly ordered) backward input variables *) - let backward_inputs : var list = - match bid with - | None -> [] - | Some back_id -> - assert (not !Config.return_back_funs); - let parents_ids = - list_ordered_ancestor_region_groups regions_hierarchy back_id - in - let backward_ids = List.append parents_ids [ back_id ] in - List.concat - (List.map - (fun id -> - T.RegionGroupId.Map.find id ctx.backward_inputs_no_state) - backward_ids) - in - (* Introduce the backward input state (the state at call site of the - * *backward* function), if necessary *) - let back_state = - if effect_info.stateful && Option.is_some bid then - let state_var = - RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars - in - [ mk_state_var state_var ] - else [] - in (* Group the inputs together *) - let inputs = - List.concat - [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ] - in + let inputs = List.concat [ fuel; ctx.forward_inputs; fwd_state ] in let inputs_lvs = List.map (fun v -> mk_typed_pattern_from_var v None) inputs in @@ -3799,16 +3614,10 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (lazy ("SymbolicToPure.translate_fun_decl: " ^ name_to_string ctx def.name - ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string bid - ^ ")" ^ "\n- forward_inputs: " + ^ "\n- inputs: " ^ String.concat ", " (List.map show_var ctx.forward_inputs) - ^ "\n- fwd_state: " + ^ "\n- state: " ^ String.concat ", " (List.map show_var fwd_state) - ^ "\n- backward_inputs: " - ^ String.concat ", " (List.map show_var backward_inputs) - ^ "\n- back_state: " - ^ String.concat ", " (List.map show_var back_state) ^ "\n- signature.inputs: " ^ String.concat ", " (List.map (pure_ty_to_string ctx) signature.inputs))); @@ -3837,7 +3646,6 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = kind = def.kind; num_loops; loop_id; - back_id = bid; llbc_name; name; signature; |