summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/SymbolicToPure.ml266
1 files changed, 209 insertions, 57 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index e2787271..1ce6c698 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -67,6 +67,18 @@ type call_info = {
Those inputs include the fuel and the state, if pertinent.
*)
+ back_funs : texpression RegionGroupId.Map.t option;
+ (** If we do not split between the forward/backward functions: the
+ variables we introduced for the backward functions.
+
+ Example:
+ {[
+ let x, back = Vec.index_mut n v in
+ ^^^^
+ here
+ ...
+ ]}
+ *)
}
[@@deriving show]
@@ -118,6 +130,8 @@ type loop_info = {
(** Body synthesis context *)
type bs_ctx = {
+ (* TODO: there are a lot of duplications with the various decls ctx *)
+ decls_ctx : C.decls_ctx;
type_ctx : type_ctx;
fun_ctx : fun_ctx;
global_ctx : global_ctx;
@@ -757,17 +771,27 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
TraitMethod (trait_ref, method_name, fun_decl_id)
let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
- (args : texpression list) (ctx : bs_ctx) : bs_ctx =
+ (args : texpression list)
+ (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx
+ =
let calls = ctx.calls in
assert (not (V.FunCallId.Map.mem call_id calls));
- let info = { forward; forward_inputs = args } in
+ let info = { forward; forward_inputs = args; back_funs } in
let calls = V.FunCallId.Map.add call_id info calls in
{ ctx with calls }
-(** [back_args]: the *additional* list of inputs received by the backward function *)
-let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
- (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
- : bs_ctx * fun_or_op_id =
+(** [inherit_args]: the list of inputs inherited from the forward function and
+ the ancestors backward functions, if pertinent.
+ [back_args]: the *additional* list of inputs received by the backward function,
+ including the state.
+
+ Returns the updated context and the expression corresponding to the function.
+ *)
+let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
+ (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id)
+ (inherited_args : texpression list) (back_args : texpression list)
+ (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) :
+ bs_ctx * texpression =
(* Insert the abstraction in the call informations *)
let info = V.FunCallId.Map.find call_id ctx.calls in
let calls = V.FunCallId.Map.add call_id info ctx.calls in
@@ -777,16 +801,31 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
let abstractions =
V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
in
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) ->
- let fid = translate_fun_id_or_trait_method_ref ctx fid in
- Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ (* Compute the expression corresponding to the function *)
+ let func =
+ if !Config.return_back_funs then
+ (* Lookup the variable introduced for the backward function *)
+ RegionGroupId.Map.find back_id (Option.get info.back_funs)
+ else
+ (* Retrieve the fun_id *)
+ let fun_id =
+ match info.forward.call_id with
+ | S.Fun (fid, _) ->
+ let fid = translate_fun_id_or_trait_method_ref ctx fid in
+ Fun (FromLlbc (fid, None, Some back_id))
+ | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ in
+ let args = List.append inherited_args back_args in
+ let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty output_ty else output_ty
+ in
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = { id = FunOrOp fun_id; generics } in
+ { e = Qualif func; ty = func_ty }
in
(* Update the context and return *)
- ({ ctx with calls; abstractions }, fun_id)
+ ({ ctx with calls; abstractions }, func)
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
@@ -878,15 +917,12 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
We use [bid] ("backward function id") only if we split the forward
and the backward functions.
*)
-let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
- (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) :
- decomposed_fun_sig =
+let translate_fun_sig_with_regions_hierarchy_to_decomposed
+ (decls_ctx : C.decls_ctx) (fun_id : A.fun_id_or_trait_method_ref)
+ (regions_hierarchy : T.region_var_groups) (sg : A.fun_sig)
+ (input_names : string option list) : decomposed_fun_sig =
let fun_infos = decls_ctx.fun_ctx.fun_infos in
let type_infos = decls_ctx.type_ctx.type_infos in
- (* Retrieve the list of parent backward functions *)
- let regions_hierarchy =
- FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies
- in
(* We need an evaluation context to normalize the types (to normalize the
associated types, etc. - for instance it may happen that the types
refer to the types associated to a trait ref, but where the trait ref
@@ -915,9 +951,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
in
(* Is the forward function stateful, and can it fail? *)
- let fwd_effect_info =
- get_fun_effect_info fun_infos (FunId fun_id) None None
- in
+ let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in
(* Compute the forward inputs *)
let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in
let fwd_inputs_no_fuel_no_state =
@@ -1030,7 +1064,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
RegionGroupId.id * back_sg_info =
let gid = rg.id in
let back_effect_info =
- get_fun_effect_info fun_infos (FunId fun_id) None (Some gid)
+ get_fun_effect_info fun_infos fun_id None (Some gid)
in
let inputs_no_state = translate_back_inputs_for_gid gid in
let inputs_no_state =
@@ -1072,6 +1106,16 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
fwd_info;
}
+let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
+ (fun_id : FunDeclId.id) (sg : A.fun_sig) (input_names : string option list)
+ : decomposed_fun_sig =
+ (* Retrieve the list of parent backward functions *)
+ let regions_hierarchy =
+ FunIdMap.find (FRegular fun_id) decls_ctx.fun_ctx.regions_hierarchies
+ in
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
+ (FunId (FRegular fun_id)) regions_hierarchy sg input_names
+
let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
=
let output =
@@ -1090,6 +1134,40 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list =
mk_arrows inputs output)
(RegionGroupId.Map.values dsg.back_sg)
+(** Return the pure signature of a backward function, in the case the
+ forward/backward functions are merged (i.e., the forward functions
+ return the backward functions).
+
+ TODO: merge with {!translate_fun_sig_from_decomposed}
+ *)
+let translate_ret_back_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
+ (gid : RegionGroupId.id) : fun_sig =
+ assert !Config.return_back_funs;
+
+ let generics = dsg.generics in
+ let llbc_generics = dsg.llbc_generics in
+ let preds = dsg.preds in
+ (* Compute the effects info *)
+ let fwd_info = dsg.fwd_info in
+ let back_effect_info =
+ RegionGroupId.Map.of_list
+ (List.map
+ (fun ((gid, info) : RegionGroupId.id * back_sg_info) ->
+ (gid, info.effect_info))
+ (RegionGroupId.Map.bindings dsg.back_sg))
+ in
+ (* Two cases depending on whether we split the forward/backward functions
+ or not *)
+ let mk_output_ty = mk_output_ty_from_effect_info in
+
+ let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
+ let effect_info = back_sg.effect_info in
+ (* Do not prepend the forward inputs *)
+ let inputs = List.map snd back_sg.inputs in
+ let output = mk_simpl_tuple_ty back_sg.outputs in
+ let output = mk_output_ty effect_info output in
+ { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info }
+
let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid : RegionGroupId.id option) : fun_sig =
let generics = dsg.generics in
@@ -1774,7 +1852,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
* if necessary. *)
- let ctx, fun_id, effect_info, args, out_state =
+ let ctx, fun_id, effect_info, args, back_funs, out_state =
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
@@ -1798,9 +1876,80 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
+ (* If we do not split the forward/backward functions: generate the
+ variables for the backward functions returned by the forward
+ function. *)
+ let ctx, back_funs_map, back_funs =
+ if !Config.return_back_funs then
+ (* We need to compute the signatures of the backward functions. *)
+ let sg = Option.get call.sg in
+ let decls_ctx = ctx.decls_ctx in
+ let dsg =
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
+ fid call.regions_hierarchy sg
+ (List.map (fun _ -> None) sg.inputs)
+ in
+ let gids =
+ List.map
+ (fun (g : T.region_var_group) -> g.id)
+ call.regions_hierarchy
+ in
+ let back_sgs =
+ List.map (translate_ret_back_fun_sig_from_decomposed dsg) gids
+ in
+ (* Introduce variables for the backward functions *)
+ let back_tys =
+ List.map
+ (fun (sg : fun_sig) -> mk_arrows sg.inputs sg.output)
+ back_sgs
+ in
+ (* Compute a proper basename for the variables *)
+ let back_fun_name =
+ let name =
+ match fid with
+ | FunId (FAssumed fid) -> (
+ match fid with
+ | BoxNew -> "box_new"
+ | BoxFree -> "box_free"
+ | ArrayRepeat -> "array_repeat"
+ | ArrayIndexShared -> "index_shared"
+ | ArrayIndexMut -> "index_mut"
+ | ArrayToSliceShared -> "to_slice_shared"
+ | ArrayToSliceMut -> "to_slice_mut"
+ | SliceIndexShared -> "index_shared"
+ | SliceIndexMut -> "index_mut")
+ | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
+ let decl =
+ FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
+ in
+ match Collections.List.last decl.name with
+ | PeIdent (s, _) -> s
+ | PeImpl _ ->
+ (* We shouldn't get there *)
+ raise (Failure "Unexpected"))
+ in
+ name ^ "_back"
+ in
+ let ctx, back_vars =
+ fresh_vars
+ (List.map (fun ty -> (Some back_fun_name, ty)) back_tys)
+ ctx
+ in
+ let back_funs =
+ List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
+ in
+ let back_funs_map =
+ RegionGroupId.Map.of_list
+ (List.combine gids (List.map mk_texpression_from_var back_vars))
+ in
+ (ctx, Some back_funs_map, back_funs)
+ else (ctx, None, [])
+ in
(* Register the function call *)
- let ctx = bs_ctx_register_forward_call call_id call args ctx in
- (ctx, func, effect_info, args, out_state)
+ let ctx =
+ bs_ctx_register_forward_call call_id call args back_funs_map ctx
+ in
+ (ctx, func, effect_info, args, back_funs, out_state)
| S.Unop E.Not ->
let effect_info =
{
@@ -1811,7 +1960,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop Not, effect_info, args, None)
+ (ctx, Unop Not, effect_info, args, [], None)
| S.Unop E.Neg -> (
match args with
| [ arg ] ->
@@ -1827,7 +1976,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Neg int_ty), effect_info, args, None)
+ (ctx, Unop (Neg int_ty), effect_info, args, [], None)
| _ -> raise (Failure "Unreachable"))
| S.Unop (E.Cast cast_kind) -> (
match cast_kind with
@@ -1842,7 +1991,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
+ (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None)
| CastFnPtr _ -> raise (Failure "TODO: function casts"))
| S.Binop binop -> (
match args with
@@ -1862,11 +2011,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Binop (binop, int_ty0), effect_info, args, None)
+ (ctx, Binop (binop, int_ty0), effect_info, args, [], None)
| _ -> raise (Failure "Unreachable"))
in
let dest_v =
let dest = mk_typed_pattern_from_var dest dest_mplace in
+ let dest = mk_simpl_tuple_pattern (dest :: back_funs) in
match out_state with
| None -> dest
| Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
@@ -2026,9 +2176,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
else ([], ctx, None)
in
(* Concatenate all the inpus *)
- let inputs =
- List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ]
+ let inherited_inputs =
+ if !Config.return_back_funs then []
+ else List.concat [ fwd_inputs; back_ancestors_inputs ]
in
+ let back_inputs = List.append back_inputs back_state in
(* Retrieve the values given back by this function: those are the output
* values. We rely on the fact that there are no nested borrows to use the
* meta-place information from the input values given to the forward function
@@ -2046,43 +2198,43 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
(* Retrieve the function id, and register the function call in the context
- * if necessary *)
+ if necessary.Arith_status *)
let ctx, func =
- bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
+ bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs
+ back_inputs generics output.ty ctx
in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
+ let inputs = List.append inherited_inputs back_inputs in
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
List.map
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output.ty else output.ty
- in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp func; generics } in
- let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* **Optimization**:
- * =================
- * We do a small optimization here: if the backward function doesn't
- * have any output, we don't introduce any function call.
- * See the comment in {!Config.filter_useless_monadic_calls}.
- *
- * TODO: use an option to disallow backward functions from updating the state.
- * TODO: a backward function which only gives back shared borrows shouldn't
- * update the state (state updates should only be used for mutable borrows,
- * with objects like Rc for instance).
- *)
- if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then (
+ =================
+ We do a small optimization here if we split the forward/backward functions.
+ If the backward function doesn't have any output, we don't introduce any function
+ call.
+ See the comment in {!Config.filter_useless_monadic_calls}.
+
+ TODO: use an option to disallow backward functions from updating the state.
+ TODO: a backward function which only gives back shared borrows shouldn't
+ update the state (state updates should only be used for mutable borrows,
+ with objects like Rc for instance).
+ *)
+ if
+ (not !Config.return_back_funs)
+ && !Config.filter_useless_monadic_calls
+ && outputs = [] && nstate = None
+ then (
(* No outputs - we do a small sanity check: the backward function
- * should have exactly the same number of inputs as the forward:
- * this number can be different only if the forward function returned
- * a value containing mutable borrows, which can't be the case... *)
+ should have exactly the same number of inputs as the forward:
+ this number can be different only if the forward function returned
+ a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else mk_let effect_info.can_fail output call next_e