summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
authorSon HO2024-03-08 16:51:40 +0100
committerGitHub2024-03-08 16:51:40 +0100
commit169d011cbfa83d853d0148bbf6b946e6ccbe4c4c (patch)
treeed8953634d14313d5b7d6ad204343d64eb990baf /compiler/SymbolicToPure.ml
parentb604bb9935007a1f0e9c7f556f8196f0e14c85ce (diff)
parent873deb005b394aca3090497e6c21ab9f8c2676be (diff)
Merge pull request #83 from AeneasVerif/son/backs
Remove the option to split the forward/backward functions
Diffstat (limited to '')
-rw-r--r--compiler/SymbolicToPure.ml949
1 files changed, 374 insertions, 575 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3a50e495..2db5f66c 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -805,11 +805,9 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
that we need to call. This function may be [None] if it has to be ignored
(because it does nothing).
*)
-let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
- (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id)
- (inherited_args : texpression list) (back_args : texpression list)
- (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) :
- bs_ctx * texpression option =
+let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
+ (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
+ : bs_ctx * texpression option =
(* Insert the abstraction in the call informations *)
let info = V.FunCallId.Map.find call_id ctx.calls in
let calls = V.FunCallId.Map.add call_id info ctx.calls in
@@ -819,29 +817,9 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
let abstractions =
V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
in
- (* Compute the expression corresponding to the function *)
- let func =
- if !Config.return_back_funs then
- (* Lookup the variable introduced for the backward function *)
- RegionGroupId.Map.find back_id (Option.get info.back_funs)
- else
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) ->
- let fid = translate_fun_id_or_trait_method_ref ctx fid in
- Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
- in
- let args = List.append inherited_args back_args in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output_ty else output_ty
- in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp fun_id; generics } in
- Some { e = Qualif func; ty = func_ty }
- in
+ (* Compute the expression corresponding to the function.
+ We simply lookup the variable introduced for the backward function. *)
+ let func = RegionGroupId.Map.find back_id (Option.get info.back_funs) in
(* Update the context and return *)
({ ctx with calls; abstractions }, func)
@@ -1124,20 +1102,34 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
let inputs_no_state =
List.map (fun ty -> (Some "ret", ty)) inputs_no_state
in
- (* In case we merge the forward/backward functions:
- we consider the backward function as stateful and potentially failing
+ (* We consider a backward function as stateful and potentially failing
**only if it has inputs** (for the "potentially failing": if it has
not inputs, we directly evaluate it in the body of the forward function).
+
+ For instance, we do the following:
+ {[
+ // Rust
+ fn push<T, 'a>(v : &mut Vec<T>, x : T) { ... }
+
+ (* Generated code: before doing unit elimination.
+ We return (), as well as the backward function; as the backward
+ function doesn't consume any inputs, it is a value that we compute
+ directly in the body of [push].
+ *)
+ let push T (v : Vec T) (x : T) : Result (() * Vec T) = ...
+
+ (* Generated code: after doing unit elimination, if we simplify the merged
+ fwd/back functions (see below). *)
+ let push T (v : Vec T) (x : T) : Result (Vec T) = ...
+ ]}
*)
let back_effect_info =
- if !Config.return_back_funs then
- let b = inputs_no_state <> [] in
- {
- back_effect_info with
- stateful = back_effect_info.stateful && b;
- can_fail = back_effect_info.can_fail && b;
- }
- else back_effect_info
+ let b = inputs_no_state <> [] in
+ {
+ back_effect_info with
+ stateful = back_effect_info.stateful && b;
+ can_fail = back_effect_info.can_fail && b;
+ }
in
let state =
if back_effect_info.stateful then [ (None, mk_state_ty) ] else []
@@ -1145,8 +1137,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
let inputs = inputs_no_state @ state in
let output_names, outputs = compute_back_outputs_for_gid gid in
let filter =
- !Config.simplify_merged_fwd_backs
- && !Config.return_back_funs && inputs = [] && outputs = []
+ !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
in
let info =
{
@@ -1186,7 +1177,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
}
in
let ignore_output =
- if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then
+ if !Config.simplify_merged_fwd_backs then
ty_is_unit fwd_output
&& List.exists
(fun (info : back_sg_info) -> not info.filter)
@@ -1296,10 +1287,10 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig)
(subst : (generic_args * trait_instance_id) option) : ty option list =
List.map (Option.map snd) (compute_back_tys_with_info dsg subst)
-(** In case we merge the fwd/back functions: compute the output type of
- a function, from a decomposed signature. *)
+(** Compute the output type of a function, from a decomposed signature
+ (the output type contains the type of the value returned by the forward
+ function as well as the types of the returned backward functions). *)
let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
- assert !Config.return_back_funs;
(* Compute the arrow types for all the backward functions *)
let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
(* Group the forward output and the types of the backward functions *)
@@ -1315,8 +1306,8 @@ let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
in
mk_output_ty_from_effect_info effect_info output
-let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
- (gid : RegionGroupId.id option) : fun_sig =
+let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) : fun_sig
+ =
let generics = dsg.generics in
let llbc_generics = dsg.llbc_generics in
let preds = dsg.preds in
@@ -1329,27 +1320,10 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid, info.effect_info))
(RegionGroupId.Map.bindings dsg.back_sg))
in
- let mk_output_ty = mk_output_ty_from_effect_info in
let inputs, output =
- (* Two cases depending on whether we split the forward/backward functions or not *)
- if !Config.return_back_funs then (
- assert (gid = None);
- let output = compute_output_ty_from_decomposed dsg in
- let inputs = dsg.fwd_inputs in
- (inputs, output))
- else
- match gid with
- | None ->
- let effect_info = dsg.fwd_info.effect_info in
- let output = mk_output_ty effect_info dsg.fwd_output in
- (dsg.fwd_inputs, output)
- | Some gid ->
- let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
- let effect_info = back_sg.effect_info in
- let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in
- let output = mk_simpl_tuple_ty back_sg.outputs in
- let output = mk_output_ty effect_info output in
- (inputs, output)
+ let output = compute_output_ty_from_decomposed dsg in
+ let inputs = dsg.fwd_inputs in
+ (inputs, output)
in
{ generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info }
@@ -1933,16 +1907,14 @@ and translate_panic (ctx : bs_ctx) : texpression =
*)
match ctx.bid with
| None ->
- if !Config.return_back_funs then
- let back_tys = compute_back_tys ctx.sg None in
- let back_tys = List.filter_map (fun x -> x) back_tys in
- let tys =
- if ctx.sg.fwd_info.ignore_output then back_tys
- else ctx.sg.fwd_output :: back_tys
- in
- let output = mk_simpl_tuple_ty tys in
- mk_output output
- else mk_output ctx.sg.fwd_output
+ let back_tys = compute_back_tys ctx.sg None in
+ let back_tys = List.filter_map (fun x -> x) back_tys in
+ let tys =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty tys in
+ mk_output output
| Some bid ->
let output =
mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs
@@ -2063,7 +2035,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| S.Fun (fid, call_id) ->
(* Regular function call *)
let fid_t = translate_fun_id_or_trait_method_ref ctx fid in
- let func = Fun (FromLlbc (fid_t, None, None)) in
+ let func = Fun (FromLlbc (fid_t, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info = get_fun_effect_info ctx fid None None in
@@ -2080,107 +2052,103 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
- (* If we do not split the forward/backward functions: generate the
- variables for the backward functions returned by the forward
+ (* Generate the variables for the backward functions returned by the forward
function. *)
let ctx, ignore_fwd_output, back_funs_map, back_funs =
- if !Config.return_back_funs then (
- (* We need to compute the signatures of the backward functions. *)
- let sg = Option.get call.sg in
- let decls_ctx = ctx.decls_ctx in
- let dsg =
- translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
- fid call.regions_hierarchy sg
- (List.map (fun _ -> None) sg.inputs)
- in
- log#ldebug
- (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics));
- let tr_self, all_generics =
- match call.trait_method_generics with
- | None -> (UnknownTrait __FUNCTION__, generics)
- | Some (all_generics, tr_self) ->
- let all_generics =
- ctx_translate_fwd_generic_args ctx all_generics
- in
- let tr_self =
- translate_fwd_trait_instance_id ctx.type_ctx.type_infos
- tr_self
+ (* We need to compute the signatures of the backward functions. *)
+ let sg = Option.get call.sg in
+ let decls_ctx = ctx.decls_ctx in
+ let dsg =
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx fid
+ call.regions_hierarchy sg
+ (List.map (fun _ -> None) sg.inputs)
+ in
+ log#ldebug
+ (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics));
+ let tr_self, all_generics =
+ match call.trait_method_generics with
+ | None -> (UnknownTrait __FUNCTION__, generics)
+ | Some (all_generics, tr_self) ->
+ let all_generics =
+ ctx_translate_fwd_generic_args ctx all_generics
+ in
+ let tr_self =
+ translate_fwd_trait_instance_id ctx.type_ctx.type_infos
+ tr_self
+ in
+ (tr_self, all_generics)
+ in
+ let back_tys =
+ compute_back_tys_with_info dsg (Some (all_generics, tr_self))
+ in
+ (* Introduce variables for the backward functions *)
+ (* Compute a proper basename for the variables *)
+ let back_fun_name =
+ let name =
+ match fid with
+ | FunId (FAssumed fid) -> (
+ match fid with
+ | BoxNew -> "box_new"
+ | BoxFree -> "box_free"
+ | ArrayRepeat -> "array_repeat"
+ | ArrayIndexShared -> "index_shared"
+ | ArrayIndexMut -> "index_mut"
+ | ArrayToSliceShared -> "to_slice_shared"
+ | ArrayToSliceMut -> "to_slice_mut"
+ | SliceIndexShared -> "index_shared"
+ | SliceIndexMut -> "index_mut")
+ | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
+ let decl =
+ FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
in
- (tr_self, all_generics)
+ match Collections.List.last decl.name with
+ | PeIdent (s, _) -> s
+ | PeImpl _ ->
+ (* We shouldn't get there *)
+ raise (Failure "Unexpected"))
in
- let back_tys =
- compute_back_tys_with_info dsg (Some (all_generics, tr_self))
- in
- (* Introduce variables for the backward functions *)
- (* Compute a proper basename for the variables *)
- let back_fun_name =
- let name =
- match fid with
- | FunId (FAssumed fid) -> (
- match fid with
- | BoxNew -> "box_new"
- | BoxFree -> "box_free"
- | ArrayRepeat -> "array_repeat"
- | ArrayIndexShared -> "index_shared"
- | ArrayIndexMut -> "index_mut"
- | ArrayToSliceShared -> "to_slice_shared"
- | ArrayToSliceMut -> "to_slice_mut"
- | SliceIndexShared -> "index_shared"
- | SliceIndexMut -> "index_mut")
- | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
- let decl =
- FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
- in
- match Collections.List.last decl.name with
- | PeIdent (s, _) -> s
- | PeImpl _ ->
- (* We shouldn't get there *)
- raise (Failure "Unexpected"))
- in
- name ^ "_back"
- in
- let ctx, back_vars =
- fresh_opt_vars
- (List.map
- (fun ty ->
- match ty with
- | None -> None
- | Some (back_sg, ty) ->
- (* We insert a name for the variable only if the function
- can fail: if it can fail, it means the call returns a backward
- function. Otherwise, we it directly returns the value given
- back by the backward function, which means we shouldn't
- give it a name like "back..." (it doesn't make sense) *)
- let name =
- if back_sg.effect_info.can_fail then
- Some back_fun_name
- else None
- in
- Some (name, ty))
- back_tys)
- ctx
- in
- let back_funs =
- List.filter_map
- (fun v ->
- match v with
- | None -> None
- | Some v -> Some (mk_typed_pattern_from_var v None))
- back_vars
- in
- let gids =
- List.map
- (fun (g : T.region_var_group) -> g.id)
- call.regions_hierarchy
- in
- let back_vars =
- List.map (Option.map mk_texpression_from_var) back_vars
- in
- let back_funs_map =
- RegionGroupId.Map.of_list (List.combine gids back_vars)
- in
- (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs))
- else (ctx, false, None, [])
+ name ^ "_back"
+ in
+ let ctx, back_vars =
+ fresh_opt_vars
+ (List.map
+ (fun ty ->
+ match ty with
+ | None -> None
+ | Some (back_sg, ty) ->
+ (* We insert a name for the variable only if the function
+ can fail: if it can fail, it means the call returns a backward
+ function. Otherwise, we it directly returns the value given
+ back by the backward function, which means we shouldn't
+ give it a name like "back..." (it doesn't make sense) *)
+ let name =
+ if back_sg.effect_info.can_fail then Some back_fun_name
+ else None
+ in
+ Some (name, ty))
+ back_tys)
+ ctx
+ in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids =
+ List.map
+ (fun (g : T.region_var_group) -> g.id)
+ call.regions_hierarchy
+ in
+ let back_vars =
+ List.map (Option.map mk_texpression_from_var) back_vars
+ in
+ let back_funs_map =
+ RegionGroupId.Map.of_list (List.combine gids back_vars)
+ in
+ (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)
in
(* Compute the pattern for the destination *)
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
@@ -2407,19 +2375,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
raise (Failure "Unreachable")
in
let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in
- let generics = ctx_translate_fwd_generic_args ctx call.generics in
- (* Retrieve the original call and the parent abstractions *)
- let _forward, backwards = get_abs_ancestors ctx abs call_id in
- (* Retrieve the values consumed when we called the forward function and
- * ended the parent backward functions: those give us part of the input
- * values (rem: for now, as we disallow nested lifetimes, there can't be
- * parent backward functions).
- * Note that the forward inputs **include the fuel and the input state**
- * (if we use those). *)
- let fwd_inputs = call_info.forward_inputs in
- let back_ancestors_inputs =
- List.concat (List.map (fun (_abs, args) -> args) backwards)
- in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: those give us the remaining input values *)
let back_inputs = abs_to_consumed ctx ectx abs in
@@ -2434,11 +2389,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
- (* Concatenate all the inpus *)
- let inherited_inputs =
- if !Config.return_back_funs then []
- else List.concat [ fwd_inputs; back_ancestors_inputs ]
- in
let back_inputs = List.append back_inputs back_state in
(* Retrieve the values given back by this function: those are the output
* values. We rely on the fact that there are no nested borrows to use the
@@ -2459,58 +2409,33 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
(* Retrieve the function id, and register the function call in the context
if necessary.Arith_status *)
let ctx, func =
- bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs
- back_inputs generics output.ty ctx
+ bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
- let inputs = List.append inherited_inputs back_inputs in
+ let inputs = back_inputs in
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
List.map
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- (* **Optimization**:
- =================
- We do a small optimization here if we split the forward/backward functions.
- If the backward function doesn't have any output, we don't introduce any function
- call.
- See the comment in {!Config.filter_useless_monadic_calls}.
-
- TODO: use an option to disallow backward functions from updating the state.
- TODO: a backward function which only gives back shared borrows shouldn't
- update the state (state updates should only be used for mutable borrows,
- with objects like Rc for instance).
- *)
- if
- (not !Config.return_back_funs)
- && !Config.filter_useless_monadic_calls
- && outputs = [] && nstate = None
- then (
- (* No outputs - we do a small sanity check: the backward function
- should have exactly the same number of inputs as the forward:
- this number can be different only if the forward function returned
- a value containing mutable borrows, which can't be the case... *)
- assert (List.length inputs = List.length fwd_inputs);
- next_e)
- else
- (* The backward function might also have been filtered if we do not
- split the forward/backward functions *)
- match func with
- | None -> next_e
- | Some func ->
- log#ldebug
- (lazy
- (let args = List.map (texpression_to_string ctx) args in
- "func: "
- ^ texpression_to_string ctx func
- ^ "\nfunc type: "
- ^ pure_ty_to_string ctx func.ty
- ^ "\n\nargs:\n" ^ String.concat "\n" args));
- let call = mk_apps func args in
- mk_let effect_info.can_fail output call next_e
+ (* The backward function might have been filtered it does nothing
+ (consumes unit and returns unit). *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ log#ldebug
+ (lazy
+ (let args = List.map (texpression_to_string ctx) args in
+ "func: "
+ ^ texpression_to_string ctx func
+ ^ "\nfunc type: "
+ ^ pure_ty_to_string ctx func.ty
+ ^ "\n\nargs:\n" ^ String.concat "\n" args));
+ let call = mk_apps func args in
+ mk_let effect_info.can_fail output call next_e
and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -2614,8 +2539,6 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
- let generics = loop_info.generics in
- let fwd_inputs = Option.get loop_info.forward_inputs in
(* Retrieve the additional backward inputs. Note that those are actually
the backward inputs of the function we are synthesizing (and that we
need to *transmit* to the loop backward function): they are not the
@@ -2637,10 +2560,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
else ([], ctx, None)
in
(* Concatenate all the inputs *)
- let inputs =
- if !Config.return_back_funs then List.concat [ back_inputs; back_state ]
- else List.concat [ fwd_inputs; back_inputs; back_state ]
- in
+ let inputs = List.concat [ back_inputs; back_state ] in
(* Retrieve the values given back by this function *)
let ctx, outputs = abs_to_given_back None abs ctx in
(* Group the output values together: first the updated inputs *)
@@ -2660,87 +2580,52 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output.ty else output.ty
- in
(* Create the expression for the function:
- it is either a call to a top-level function, if we split the
forward/backward functions
- or a call to the variable we introduced for the backward function,
if we merge the forward/backward functions *)
let func =
- if !Config.return_back_funs then
- RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
- else
- let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
- let func = { id = FunOrOp func; generics } in
- Some { e = Qualif func; ty = func_ty }
+ RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
in
- (* **Optimization**:
- =================
- We do a small optimization here in case we split the forward/backward
- functions.
- If the backward function doesn't have any output, we don't introduce
- any function call.
- See the comment in {!Config.filter_useless_monadic_calls}.
-
- TODO: use an option to disallow backward functions from updating the state.
- TODO: a backward function which only gives back shared borrows shouldn't
- update the state (state updates should only be used for mutable borrows,
- with objects like Rc for instance).
- *)
- if
- (not !Config.return_back_funs)
- && !Config.filter_useless_monadic_calls
- && outputs = [] && nstate = None
- then (
- (* No outputs - we do a small sanity check: the backward function
- should have exactly the same number of inputs as the forward:
- this number can be different only if the forward function returned
- a value containing mutable borrows, which can't be the case... *)
- assert (List.length inputs = List.length fwd_inputs);
- next_e)
- else
- (* In case we merge the fwd/back functions we filter the backward
- functions elsewhere *)
- match func with
- | None -> next_e
- | Some func ->
- let call = mk_apps func args in
- (* Add meta-information - this is slightly hacky: we look at the
- values consumed by the abstraction (note that those come from
- *before* we applied the fixed-point context) and use them to
- guide the naming of the output vars.
-
- Also, we need to convert the backward outputs from patterns to
- variables.
-
- Finally, in practice, this works well only for loop bodies:
- we do this only in this case.
- TODO: improve the heuristics, to give weight to the hints for
- instance.
- *)
- let next_e =
- if ctx.inside_loop then
- let consumed_values = abs_to_consumed ctx ectx abs in
- let var_values = List.combine outputs consumed_values in
- let var_values =
- List.filter_map
- (fun (var, v) ->
- match var.Pure.value with
- | PatVar (var, _) -> Some (var, v)
- | _ -> None)
- var_values
- in
- let vars, values = List.split var_values in
- mk_emeta_symbolic_assignments vars values next_e
- else next_e
- in
+ (* We may have filtered the backward function elsewhere if it doesn't
+ do anything (doesn't consume anything and doesn't return anything) *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ let call = mk_apps func args in
+ (* Add meta-information - this is slightly hacky: we look at the
+ values consumed by the abstraction (note that those come from
+ *before* we applied the fixed-point context) and use them to
+ guide the naming of the output vars.
+
+ Also, we need to convert the backward outputs from patterns to
+ variables.
+
+ Finally, in practice, this works well only for loop bodies:
+ we do this only in this case.
+ TODO: improve the heuristics, to give weight to the hints for
+ instance.
+ *)
+ let next_e =
+ if ctx.inside_loop then
+ let consumed_values = abs_to_consumed ctx ectx abs in
+ let var_values = List.combine outputs consumed_values in
+ let var_values =
+ List.filter_map
+ (fun (var, v) ->
+ match var.Pure.value with
+ | PatVar (var, _) -> Some (var, v)
+ | _ -> None)
+ var_values
+ in
+ let vars, values = List.split var_values in
+ mk_emeta_symbolic_assignments vars values next_e
+ else next_e
+ in
- (* Create the let-binding *)
- mk_let effect_info.can_fail output call next_e)
+ (* Create the let-binding *)
+ mk_let effect_info.can_fail output call next_e)
and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -3068,48 +2953,40 @@ and translate_forward_end (ectx : C.eval_ctx)
*)
let ctx =
(* Introduce variables for the inputs and the state variable
- and update the context. *)
- if !Config.return_back_funs then
- (* If the forward/backward functions are not split, we need
- to introduce fresh variables for the additional inputs,
- because they are locally introduced in a lambda *)
- let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
- let ctx, backward_inputs_no_state =
- fresh_vars back_sg.inputs_no_state ctx
- in
- let ctx, backward_inputs_with_state =
- if back_sg.effect_info.stateful then
- let ctx, var, _ = bs_ctx_fresh_state_var ctx in
- (ctx, backward_inputs_no_state @ [ var ])
- else (ctx, backward_inputs_no_state)
- in
- {
- ctx with
- backward_inputs_no_state =
- RegionGroupId.Map.add bid backward_inputs_no_state
- ctx.backward_inputs_no_state;
- backward_inputs_with_state =
- RegionGroupId.Map.add bid backward_inputs_with_state
- ctx.backward_inputs_with_state;
- }
- else
- (* Update the state variable *)
- let back_state_var =
- RegionGroupId.Map.find bid ctx.back_state_vars
- in
- { ctx with state_var = back_state_var }
+ and update the context.
+
+ We need to introduce fresh variables for the additional inputs,
+ because they are locally introduced in a lambda.
+ *)
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ let ctx, backward_inputs_no_state =
+ fresh_vars back_sg.inputs_no_state ctx
+ in
+ let ctx, backward_inputs_with_state =
+ if back_sg.effect_info.stateful then
+ let ctx, var, _ = bs_ctx_fresh_state_var ctx in
+ (ctx, backward_inputs_no_state @ [ var ])
+ else (ctx, backward_inputs_no_state)
+ in
+ {
+ ctx with
+ backward_inputs_no_state =
+ RegionGroupId.Map.add bid backward_inputs_no_state
+ ctx.backward_inputs_no_state;
+ backward_inputs_with_state =
+ RegionGroupId.Map.add bid backward_inputs_with_state
+ ctx.backward_inputs_with_state;
+ }
in
let e = T.RegionGroupId.Map.find bid back_e in
let finish e =
(* Wrap in lambdas if necessary *)
- if !Config.return_back_funs then
- let inputs =
- RegionGroupId.Map.find bid ctx.backward_inputs_with_state
- in
- let places = List.map (fun _ -> None) inputs in
- mk_lambdas_from_vars inputs places e
- else e
+ let inputs =
+ RegionGroupId.Map.find bid ctx.backward_inputs_with_state
+ in
+ let places = List.map (fun _ -> None) inputs in
+ mk_lambdas_from_vars inputs places e
in
(ctx, e, finish)
in
@@ -3131,85 +3008,83 @@ and translate_forward_end (ectx : C.eval_ctx)
function, if needs be, and lookup the proper expression.
*)
let translate_end ctx =
- if !Config.return_back_funs then
- (* Compute the output of the forward function *)
- let fwd_effect_info = ctx.sg.fwd_info.effect_info in
- let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
- let fwd_e = translate_one_end ctx None in
-
- (* Introduce the backward functions. *)
- let back_el =
- List.map
- (fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
- translate_one_end ctx (Some gid))
- (RegionGroupId.Map.bindings ctx.sg.back_sg)
- in
+ (* Compute the output of the forward function *)
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
+ let fwd_e = translate_one_end ctx None in
- (* Compute whether the backward expressions should be evaluated straight
- away or not (i.e., if we should bind them with monadic let-bindings
- or not). We evaluate them straight away if they can fail and have no
- inputs. *)
- let evaluate_backs =
- List.map
- (fun (sg : back_sg_info) ->
- if !Config.simplify_merged_fwd_backs then
- sg.inputs = [] && sg.effect_info.can_fail
- else false)
- (RegionGroupId.Map.values ctx.sg.back_sg)
- in
+ (* Introduce the backward functions. *)
+ let back_el =
+ List.map
+ (fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
+ translate_one_end ctx (Some gid))
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
- (* Introduce variables for the backward functions.
- We lookup the LLBC definition in an attempt to derive pretty names
- for those functions. *)
- let _, back_vars = fresh_back_vars_for_current_fun ctx in
+ (* Compute whether the backward expressions should be evaluated straight
+ away or not (i.e., if we should bind them with monadic let-bindings
+ or not). We evaluate them straight away if they can fail and have no
+ inputs. *)
+ let evaluate_backs =
+ List.map
+ (fun (sg : back_sg_info) ->
+ if !Config.simplify_merged_fwd_backs then
+ sg.inputs = [] && sg.effect_info.can_fail
+ else false)
+ (RegionGroupId.Map.values ctx.sg.back_sg)
+ in
- (* Create the return expressions *)
- let vars =
- let back_vars = List.filter_map (fun x -> x) back_vars in
- if ctx.sg.fwd_info.ignore_output then back_vars
- else pure_fwd_var :: back_vars
- in
- let vars = List.map mk_texpression_from_var vars in
- let ret = mk_simpl_tuple_texpression vars in
-
- (* Introduce a fresh input state variable for the forward expression *)
- let _ctx, state_var, state_pat =
- if fwd_effect_info.stateful then
- let ctx, var, pat = bs_ctx_fresh_state_var ctx in
- (ctx, [ var ], [ pat ])
- else (ctx, [], [])
- in
+ (* Introduce variables for the backward functions.
+ We lookup the LLBC definition in an attempt to derive pretty names
+ for those functions. *)
+ let _, back_vars = fresh_back_vars_for_current_fun ctx in
- let state_var = List.map mk_texpression_from_var state_var in
- let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
- let ret = mk_result_return_texpression ret in
+ (* Create the return expressions *)
+ let vars =
+ let back_vars = List.filter_map (fun x -> x) back_vars in
+ if ctx.sg.fwd_info.ignore_output then back_vars
+ else pure_fwd_var :: back_vars
+ in
+ let vars = List.map mk_texpression_from_var vars in
+ let ret = mk_simpl_tuple_texpression vars in
+
+ (* Introduce a fresh input state variable for the forward expression *)
+ let _ctx, state_var, state_pat =
+ if fwd_effect_info.stateful then
+ let ctx, var, pat = bs_ctx_fresh_state_var ctx in
+ (ctx, [ var ], [ pat ])
+ else (ctx, [], [])
+ in
- (* Introduce all the let-bindings *)
+ let state_var = List.map mk_texpression_from_var state_var in
+ let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
+ let ret = mk_result_return_texpression ret in
- (* Combine:
- - the backward variables
- - whether we should evaluate the expression for the backward function
- (i.e., should we use a monadic let-binding or not - we do if the
- backward functions don't have inputs and can fail)
- - the expressions for the backward functions
- *)
- let back_vars_els =
- List.filter_map
- (fun (v, (eval, el)) ->
- match v with None -> None | Some v -> Some (v, eval, el))
- (List.combine back_vars (List.combine evaluate_backs back_el))
- in
- let e =
- List.fold_right
- (fun (var, evaluate, back_e) e ->
- mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
- back_vars_els ret
- in
- (* Bind the expression for the forward output *)
- let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
- let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
- mk_let fwd_effect_info.can_fail pat fwd_e e
- else translate_one_end ctx ctx.bid
+ (* Introduce all the let-bindings *)
+
+ (* Combine:
+ - the backward variables
+ - whether we should evaluate the expression for the backward function
+ (i.e., should we use a monadic let-binding or not - we do if the
+ backward functions don't have inputs and can fail)
+ - the expressions for the backward functions
+ *)
+ let back_vars_els =
+ List.filter_map
+ (fun (v, (eval, el)) ->
+ match v with None -> None | Some v -> Some (v, eval, el))
+ (List.combine back_vars (List.combine evaluate_backs back_el))
+ in
+ let e =
+ List.fold_right
+ (fun (var, evaluate, back_e) e ->
+ mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
+ back_vars_els ret
+ in
+ (* Bind the expression for the forward output *)
+ let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
+ let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
+ mk_let fwd_effect_info.can_fail pat fwd_e e
in
(* If we are (re-)entering a loop, we need to introduce a call to the
@@ -3279,24 +3154,22 @@ and translate_forward_end (ectx : C.eval_ctx)
backward functions of the outer function.
*)
let ctx, back_funs_map, back_funs =
- if !Config.return_back_funs then
- let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
- let back_funs =
- List.filter_map
- (fun v ->
- match v with
- | None -> None
- | Some v -> Some (mk_typed_pattern_from_var v None))
- back_vars
- in
- let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
- let back_funs_map =
- RegionGroupId.Map.of_list
- (List.combine gids
- (List.map (Option.map mk_texpression_from_var) back_vars))
- in
- (ctx, Some back_funs_map, back_funs)
- else (ctx, None, [])
+ let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
+ let back_funs_map =
+ RegionGroupId.Map.of_list
+ (List.combine gids
+ (List.map (Option.map mk_texpression_from_var) back_vars))
+ in
+ (ctx, Some back_funs_map, back_funs)
in
(* Introduce patterns *)
@@ -3339,7 +3212,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let out_pat = mk_simpl_tuple_pattern out_pats in
let loop_call =
- let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in
+ let fun_id = Fun (FromLlbc (FunId fid, Some loop_id)) in
let func = { id = FunOrOp fun_id; generics = loop_info.generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
@@ -3438,91 +3311,58 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* The output type of the loop function *)
let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in
let back_effect_infos, output_ty =
- if !Config.return_back_funs then
- (* The loop backward functions consume the same additional inputs as the parent
- function, but have custom outputs *)
- let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
- let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
- let back_info_tys =
- List.map
- (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
- (* Remark: the effect info of the backward function for the loop
- is almost the same as for the backward function of the parent function.
- Quite importantly, the fact that the function is stateful and/or can fail
- mostly depends on whether it has inputs or not, and the backward functions
- for the loops have the same inputs as the backward functions for the parent
- function.
- *)
- let effect_info = back_sg.effect_info in
- let effect_info = { effect_info with is_rec = true } in
- (* Compute the input/output types *)
- let inputs = List.map snd back_sg.inputs in
- let outputs = given_back in
- (* Filter if necessary *)
- let ty =
- if
- !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
- then None
- else
- let output = mk_simpl_tuple_ty outputs in
- let output =
- mk_back_output_ty_from_effect_info effect_info inputs output
- in
- let ty = mk_arrows inputs output in
- Some ty
- in
- ((id, effect_info), ty))
- (List.combine back_sgs given_back_tys)
- in
- let back_info = List.map fst back_info_tys in
- let back_info = RegionGroupId.Map.of_list back_info in
- let back_tys = List.filter_map snd back_info_tys in
- let output =
- if ctx.sg.fwd_info.ignore_output then back_tys
- else ctx.sg.fwd_output :: back_tys
- in
- let output = mk_simpl_tuple_ty output in
- let effect_info = ctx.sg.fwd_info.effect_info in
- let output =
- if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- let output =
- if effect_info.can_fail && inputs <> [] then mk_result_ty output
- else output
- in
- (back_info, output)
- else
- let back_info =
- RegionGroupId.Map.of_list
- (List.map
- (fun ((id, back_sg) : _ * back_sg_info) ->
- (id, { back_sg.effect_info with is_rec = true }))
- (RegionGroupId.Map.bindings ctx.sg.back_sg))
- in
- let output =
- match ctx.bid with
- | None ->
- (* Forward function: same type as the parent function *)
- (translate_fun_sig_from_decomposed ctx.sg None).output
- | Some rg_id ->
- (* Backward function: custom return type *)
- let doutputs =
- T.RegionGroupId.Map.find rg_id rg_to_given_back_tys
- in
- let output = mk_simpl_tuple_ty doutputs in
- let fwd_effect_info = ctx.sg.fwd_info.effect_info in
- let output =
- if fwd_effect_info.stateful then
- mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- let output =
- if fwd_effect_info.can_fail then mk_result_ty output else output
- in
- output
- in
- (back_info, output)
+ (* The loop backward functions consume the same additional inputs as the parent
+ function, but have custom outputs *)
+ let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
+ let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
+ let back_info_tys =
+ List.map
+ (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
+ (* Remark: the effect info of the backward function for the loop
+ is almost the same as for the backward function of the parent function.
+ Quite importantly, the fact that the function is stateful and/or can fail
+ mostly depends on whether it has inputs or not, and the backward functions
+ for the loops have the same inputs as the backward functions for the parent
+ function.
+ *)
+ let effect_info = back_sg.effect_info in
+ let effect_info = { effect_info with is_rec = true } in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = given_back in
+ (* Filter if necessary *)
+ let ty =
+ if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+ then None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ Some ty
+ in
+ ((id, effect_info), ty))
+ (List.combine back_sgs given_back_tys)
+ in
+ let back_info = List.map fst back_info_tys in
+ let back_info = RegionGroupId.Map.of_list back_info in
+ let back_tys = List.filter_map snd back_info_tys in
+ let output =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty output in
+ let effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output
+ else output
+ in
+ (back_info, output)
in
(* Add the loop information in the context *)
@@ -3708,31 +3548,26 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(* Translate *)
let def = ctx.fun_decl in
- let bid = ctx.bid in
+ assert (ctx.bid = None);
log#ldebug
(lazy
("SymbolicToPure.translate_fun_decl: "
^ name_to_string ctx def.name
- ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string bid
- ^ ")\n"));
+ ^ "\n"));
(* Translate the declaration *)
let def_id = def.def_id in
let llbc_name = def.name in
let name = name_to_string ctx llbc_name in
(* Translate the signature *)
- let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in
- let regions_hierarchy =
- FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies
- in
+ let signature = translate_fun_sig_from_decomposed ctx.sg in
(* Translate the body, if there is *)
let body =
match body with
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx (FunId (FRegular def_id)) None bid
+ get_fun_effect_info ctx (FunId (FRegular def_id)) None None
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
@@ -3760,37 +3595,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
if effect_info.stateful_group then [ mk_state_var ctx.state_var ]
else []
in
- (* Compute the list of (properly ordered) backward input variables *)
- let backward_inputs : var list =
- match bid with
- | None -> []
- | Some back_id ->
- assert (not !Config.return_back_funs);
- let parents_ids =
- list_ordered_ancestor_region_groups regions_hierarchy back_id
- in
- let backward_ids = List.append parents_ids [ back_id ] in
- List.concat
- (List.map
- (fun id ->
- T.RegionGroupId.Map.find id ctx.backward_inputs_no_state)
- backward_ids)
- in
- (* Introduce the backward input state (the state at call site of the
- * *backward* function), if necessary *)
- let back_state =
- if effect_info.stateful && Option.is_some bid then
- let state_var =
- RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars
- in
- [ mk_state_var state_var ]
- else []
- in
(* Group the inputs together *)
- let inputs =
- List.concat
- [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ]
- in
+ let inputs = List.concat [ fuel; ctx.forward_inputs; fwd_state ] in
let inputs_lvs =
List.map (fun v -> mk_typed_pattern_from_var v None) inputs
in
@@ -3799,16 +3605,10 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(lazy
("SymbolicToPure.translate_fun_decl: "
^ name_to_string ctx def.name
- ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string bid
- ^ ")" ^ "\n- forward_inputs: "
+ ^ "\n- inputs: "
^ String.concat ", " (List.map show_var ctx.forward_inputs)
- ^ "\n- fwd_state: "
+ ^ "\n- state: "
^ String.concat ", " (List.map show_var fwd_state)
- ^ "\n- backward_inputs: "
- ^ String.concat ", " (List.map show_var backward_inputs)
- ^ "\n- back_state: "
- ^ String.concat ", " (List.map show_var back_state)
^ "\n- signature.inputs: "
^ String.concat ", "
(List.map (pure_ty_to_string ctx) signature.inputs)));
@@ -3837,7 +3637,6 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
kind = def.kind;
num_loops;
loop_id;
- back_id = bid;
llbc_name;
name;
signature;