summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-12-21 14:37:43 +0100
committerSon Ho2023-12-21 14:37:43 +0100
commit8835d87df111d09122267fadc9a32f16b52d234a (patch)
tree43a1fd0e3ec0e8b94834744396b86bbd3084c888 /compiler
parente90b23a0d42e2ea6805c88d6eaa4f9e5370a1dc1 (diff)
Make good progress on merging the fwd/back functions
Diffstat (limited to '')
-rw-r--r--compiler/Config.ml2
-rw-r--r--compiler/Extract.ml4
-rw-r--r--compiler/InterpreterStatements.ml47
-rw-r--r--compiler/PureUtils.ml19
-rw-r--r--compiler/SymbolicAst.ml4
-rw-r--r--compiler/SymbolicToPure.ml266
-rw-r--r--compiler/SynthesizeSymbolic.ml16
-rw-r--r--compiler/Translate.ml3
8 files changed, 274 insertions, 87 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index c8f3ed58..b8af6c6d 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -153,7 +153,7 @@ let loop_fixed_point_max_num_iters = 2
return (x :: ls)))
]}
*)
-let return_back_funs = ref false
+let return_back_funs = ref true
(** Forbids using field projectors for structures.
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 3429cd11..46cf8c4a 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1332,7 +1332,9 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
in
let fwd_back_comment =
match def.back_id with
- | None -> [ "forward function" ]
+ | None ->
+ if !Config.return_back_funs then [ "function definition" ]
+ else [ "forward function" ]
| Some id ->
(* Check if there is only one backward function, and no forward function *)
if (not keep_fwd) && num_backs = 1 then
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index da617c64..94c65b5c 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -728,7 +728,12 @@ let create_push_abstractions_from_abs_region_groups
to a trait clause but directly to the method provided in the trait declaration.
*)
let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
- : fun_id_or_trait_method_ref * generic_args * fun_decl * inst_fun_sig =
+ :
+ fun_id_or_trait_method_ref
+ * generic_args
+ * fun_decl
+ * region_var_groups
+ * inst_fun_sig =
match call.func with
| FnOpMove _ ->
(* Closure case: TODO *)
@@ -753,7 +758,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
instantiate_fun_sig ctx func.generics tr_self def.signature
regions_hierarchy
in
- (func.func, func.generics, def, inst_sg)
+ (func.func, func.generics, def, regions_hierarchy, inst_sg)
| FunId (FAssumed _) ->
(* Unreachable: must be a transparent function *)
raise (Failure "Unreachable")
@@ -806,7 +811,7 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
we also need to update the generics.
*)
let func = FunId fid in
- (func, generics, method_def, inst_sg)
+ (func, generics, method_def, regions_hierarchy, inst_sg)
| None ->
(* If not found, lookup the methods provided by the trait *declaration*
(remember: for now, we forbid overriding provided methods) *)
@@ -860,7 +865,11 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
instantiate_fun_sig ctx all_generics tr_self
method_def.signature regions_hierarchy
in
- (func.func, func.generics, method_def, inst_sg))
+ ( func.func,
+ func.generics,
+ method_def,
+ regions_hierarchy,
+ inst_sg ))
| _ ->
(* We are using a local clause - we lookup the trait decl *)
let trait_decl =
@@ -891,7 +900,8 @@ let eval_transparent_function_call_symbolic_inst (call : call) (ctx : eval_ctx)
instantiate_fun_sig ctx generics tr_self method_def.signature
regions_hierarchy
in
- (func.func, func.generics, method_def, inst_sg)))
+ (func.func, func.generics, method_def, regions_hierarchy, inst_sg)
+ ))
(** Evaluate a statement *)
let rec eval_statement (config : config) (st : statement) : st_cm_fun =
@@ -1277,14 +1287,14 @@ and eval_transparent_function_call_concrete (config : config)
and eval_transparent_function_call_symbolic (config : config) (call : call) :
st_cm_fun =
fun cf ctx ->
- let func, generics, def, inst_sg =
+ let func, generics, def, regions_hierarchy, inst_sg =
eval_transparent_function_call_symbolic_inst call ctx
in
(* Sanity check *)
assert (List.length call.args = List.length def.signature.inputs);
(* Evaluate the function call *)
- eval_function_call_symbolic_from_inst_sig config func inst_sg generics
- call.args call.dest cf ctx
+ eval_function_call_symbolic_from_inst_sig config func def.signature
+ regions_hierarchy inst_sg generics call.args call.dest cf ctx
(** Evaluate a function call in symbolic mode by using the function signature.
@@ -1298,7 +1308,8 @@ and eval_transparent_function_call_symbolic (config : config) (call : call) :
trait ref as input.
*)
and eval_function_call_symbolic_from_inst_sig (config : config)
- (fid : fun_id_or_trait_method_ref) (inst_sg : inst_fun_sig)
+ (fid : fun_id_or_trait_method_ref) (sg : fun_sig)
+ (regions_hierarchy : region_var_groups) (inst_sg : inst_fun_sig)
(generics : generic_args) (args : operand list) (dest : place) : st_cm_fun =
fun cf ctx ->
log#ldebug
@@ -1378,8 +1389,8 @@ and eval_function_call_symbolic_from_inst_sig (config : config)
let expr = cf ctx in
(* Synthesize the symbolic AST *)
- S.synthesize_regular_function_call fid call_id ctx abs_ids generics args
- args_places ret_spc dest_place expr
+ S.synthesize_regular_function_call fid call_id ctx sg regions_hierarchy
+ abs_ids generics args args_places ret_spc dest_place expr
in
let cc = comp cc cf_call in
@@ -1468,7 +1479,7 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id)
(* In symbolic mode, the behaviour of a function call is completely defined
* by the signature of the function: we thus simply generate correctly
* instantiated signatures, and delegate the work to an auxiliary function *)
- let inst_sig =
+ let sg, regions_hierarchy, inst_sig =
match fid with
| BoxFree ->
(* Should have been treated above *)
@@ -1480,14 +1491,16 @@ and eval_assumed_function_call_symbolic (config : config) (fid : assumed_fun_id)
in
(* There shouldn't be any reference to Self *)
let tr_self = UnknownTrait __FUNCTION__ in
- instantiate_fun_sig ctx generics tr_self
- (Assumed.get_assumed_fun_sig fid)
- regions_hierarchy
+ let sg = Assumed.get_assumed_fun_sig fid in
+ let inst_sg =
+ instantiate_fun_sig ctx generics tr_self sg regions_hierarchy
+ in
+ (sg, regions_hierarchy, inst_sg)
in
(* Evaluate the function call *)
- eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid))
- inst_sig generics args dest cf ctx
+ eval_function_call_symbolic_from_inst_sig config (FunId (FAssumed fid)) sg
+ regions_hierarchy inst_sig generics args dest cf ctx
(** Evaluate a statement seen as a function body *)
and eval_function_body (config : config) (body : statement) : st_cm_fun =
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 6e86578c..6579e84c 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -321,14 +321,23 @@ let destruct_apps (e : texpression) : texpression * texpression list =
(** Make an [App (app, arg)] expression *)
let mk_app (app : texpression) (arg : texpression) : texpression =
+ let raise_or_return msg =
+ if !Config.fail_hard then raise (Failure msg)
+ else
+ let e = App (app, arg) in
+ (* Dummy type - TODO: introduce an error type *)
+ let ty = app.ty in
+ { e; ty }
+ in
match app.ty with
| TArrow (ty0, ty1) ->
(* Sanity check *)
- assert (ty0 = arg.ty);
- let e = App (app, arg) in
- let ty = ty1 in
- { e; ty }
- | _ -> raise (Failure "Expected an arrow type")
+ if ty0 <> arg.ty then raise_or_return "App: wrong input type"
+ else
+ let e = App (app, arg) in
+ let ty = ty1 in
+ { e; ty }
+ | _ -> raise_or_return "Expected an arrow type"
(** The reverse of {!destruct_apps} *)
let mk_apps (app : texpression) (args : texpression list) : texpression =
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 53f99b7f..54d207d9 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -42,7 +42,11 @@ type call = {
evaluated). We need it to compute the translated values for shared
borrows (we need to perform lookups).
*)
+ sg : fun_sig option;
+ (** The uninstantiated function signature, if this is not a unop/binop *)
+ regions_hierarchy : region_var_groups;
abstractions : AbstractionId.id list;
+ (** The region abstractions introduced upon calling the function *)
generics : generic_args;
args : typed_value list;
args_places : mplace option list; (** Meta information *)
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index e2787271..1ce6c698 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -67,6 +67,18 @@ type call_info = {
Those inputs include the fuel and the state, if pertinent.
*)
+ back_funs : texpression RegionGroupId.Map.t option;
+ (** If we do not split between the forward/backward functions: the
+ variables we introduced for the backward functions.
+
+ Example:
+ {[
+ let x, back = Vec.index_mut n v in
+ ^^^^
+ here
+ ...
+ ]}
+ *)
}
[@@deriving show]
@@ -118,6 +130,8 @@ type loop_info = {
(** Body synthesis context *)
type bs_ctx = {
+ (* TODO: there are a lot of duplications with the various decls ctx *)
+ decls_ctx : C.decls_ctx;
type_ctx : type_ctx;
fun_ctx : fun_ctx;
global_ctx : global_ctx;
@@ -757,17 +771,27 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
TraitMethod (trait_ref, method_name, fun_decl_id)
let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
- (args : texpression list) (ctx : bs_ctx) : bs_ctx =
+ (args : texpression list)
+ (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx
+ =
let calls = ctx.calls in
assert (not (V.FunCallId.Map.mem call_id calls));
- let info = { forward; forward_inputs = args } in
+ let info = { forward; forward_inputs = args; back_funs } in
let calls = V.FunCallId.Map.add call_id info calls in
{ ctx with calls }
-(** [back_args]: the *additional* list of inputs received by the backward function *)
-let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
- (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
- : bs_ctx * fun_or_op_id =
+(** [inherit_args]: the list of inputs inherited from the forward function and
+ the ancestors backward functions, if pertinent.
+ [back_args]: the *additional* list of inputs received by the backward function,
+ including the state.
+
+ Returns the updated context and the expression corresponding to the function.
+ *)
+let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
+ (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id)
+ (inherited_args : texpression list) (back_args : texpression list)
+ (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) :
+ bs_ctx * texpression =
(* Insert the abstraction in the call informations *)
let info = V.FunCallId.Map.find call_id ctx.calls in
let calls = V.FunCallId.Map.add call_id info ctx.calls in
@@ -777,16 +801,31 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
let abstractions =
V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
in
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) ->
- let fid = translate_fun_id_or_trait_method_ref ctx fid in
- Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ (* Compute the expression corresponding to the function *)
+ let func =
+ if !Config.return_back_funs then
+ (* Lookup the variable introduced for the backward function *)
+ RegionGroupId.Map.find back_id (Option.get info.back_funs)
+ else
+ (* Retrieve the fun_id *)
+ let fun_id =
+ match info.forward.call_id with
+ | S.Fun (fid, _) ->
+ let fid = translate_fun_id_or_trait_method_ref ctx fid in
+ Fun (FromLlbc (fid, None, Some back_id))
+ | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ in
+ let args = List.append inherited_args back_args in
+ let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty output_ty else output_ty
+ in
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = { id = FunOrOp fun_id; generics } in
+ { e = Qualif func; ty = func_ty }
in
(* Update the context and return *)
- ({ ctx with calls; abstractions }, fun_id)
+ ({ ctx with calls; abstractions }, func)
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
@@ -878,15 +917,12 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
We use [bid] ("backward function id") only if we split the forward
and the backward functions.
*)
-let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
- (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) :
- decomposed_fun_sig =
+let translate_fun_sig_with_regions_hierarchy_to_decomposed
+ (decls_ctx : C.decls_ctx) (fun_id : A.fun_id_or_trait_method_ref)
+ (regions_hierarchy : T.region_var_groups) (sg : A.fun_sig)
+ (input_names : string option list) : decomposed_fun_sig =
let fun_infos = decls_ctx.fun_ctx.fun_infos in
let type_infos = decls_ctx.type_ctx.type_infos in
- (* Retrieve the list of parent backward functions *)
- let regions_hierarchy =
- FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies
- in
(* We need an evaluation context to normalize the types (to normalize the
associated types, etc. - for instance it may happen that the types
refer to the types associated to a trait ref, but where the trait ref
@@ -915,9 +951,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
in
(* Is the forward function stateful, and can it fail? *)
- let fwd_effect_info =
- get_fun_effect_info fun_infos (FunId fun_id) None None
- in
+ let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in
(* Compute the forward inputs *)
let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in
let fwd_inputs_no_fuel_no_state =
@@ -1030,7 +1064,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
RegionGroupId.id * back_sg_info =
let gid = rg.id in
let back_effect_info =
- get_fun_effect_info fun_infos (FunId fun_id) None (Some gid)
+ get_fun_effect_info fun_infos fun_id None (Some gid)
in
let inputs_no_state = translate_back_inputs_for_gid gid in
let inputs_no_state =
@@ -1072,6 +1106,16 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
fwd_info;
}
+let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
+ (fun_id : FunDeclId.id) (sg : A.fun_sig) (input_names : string option list)
+ : decomposed_fun_sig =
+ (* Retrieve the list of parent backward functions *)
+ let regions_hierarchy =
+ FunIdMap.find (FRegular fun_id) decls_ctx.fun_ctx.regions_hierarchies
+ in
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
+ (FunId (FRegular fun_id)) regions_hierarchy sg input_names
+
let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
=
let output =
@@ -1090,6 +1134,40 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list =
mk_arrows inputs output)
(RegionGroupId.Map.values dsg.back_sg)
+(** Return the pure signature of a backward function, in the case the
+ forward/backward functions are merged (i.e., the forward functions
+ return the backward functions).
+
+ TODO: merge with {!translate_fun_sig_from_decomposed}
+ *)
+let translate_ret_back_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
+ (gid : RegionGroupId.id) : fun_sig =
+ assert !Config.return_back_funs;
+
+ let generics = dsg.generics in
+ let llbc_generics = dsg.llbc_generics in
+ let preds = dsg.preds in
+ (* Compute the effects info *)
+ let fwd_info = dsg.fwd_info in
+ let back_effect_info =
+ RegionGroupId.Map.of_list
+ (List.map
+ (fun ((gid, info) : RegionGroupId.id * back_sg_info) ->
+ (gid, info.effect_info))
+ (RegionGroupId.Map.bindings dsg.back_sg))
+ in
+ (* Two cases depending on whether we split the forward/backward functions
+ or not *)
+ let mk_output_ty = mk_output_ty_from_effect_info in
+
+ let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
+ let effect_info = back_sg.effect_info in
+ (* Do not prepend the forward inputs *)
+ let inputs = List.map snd back_sg.inputs in
+ let output = mk_simpl_tuple_ty back_sg.outputs in
+ let output = mk_output_ty effect_info output in
+ { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info }
+
let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid : RegionGroupId.id option) : fun_sig =
let generics = dsg.generics in
@@ -1774,7 +1852,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
* if necessary. *)
- let ctx, fun_id, effect_info, args, out_state =
+ let ctx, fun_id, effect_info, args, back_funs, out_state =
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
@@ -1798,9 +1876,80 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
+ (* If we do not split the forward/backward functions: generate the
+ variables for the backward functions returned by the forward
+ function. *)
+ let ctx, back_funs_map, back_funs =
+ if !Config.return_back_funs then
+ (* We need to compute the signatures of the backward functions. *)
+ let sg = Option.get call.sg in
+ let decls_ctx = ctx.decls_ctx in
+ let dsg =
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
+ fid call.regions_hierarchy sg
+ (List.map (fun _ -> None) sg.inputs)
+ in
+ let gids =
+ List.map
+ (fun (g : T.region_var_group) -> g.id)
+ call.regions_hierarchy
+ in
+ let back_sgs =
+ List.map (translate_ret_back_fun_sig_from_decomposed dsg) gids
+ in
+ (* Introduce variables for the backward functions *)
+ let back_tys =
+ List.map
+ (fun (sg : fun_sig) -> mk_arrows sg.inputs sg.output)
+ back_sgs
+ in
+ (* Compute a proper basename for the variables *)
+ let back_fun_name =
+ let name =
+ match fid with
+ | FunId (FAssumed fid) -> (
+ match fid with
+ | BoxNew -> "box_new"
+ | BoxFree -> "box_free"
+ | ArrayRepeat -> "array_repeat"
+ | ArrayIndexShared -> "index_shared"
+ | ArrayIndexMut -> "index_mut"
+ | ArrayToSliceShared -> "to_slice_shared"
+ | ArrayToSliceMut -> "to_slice_mut"
+ | SliceIndexShared -> "index_shared"
+ | SliceIndexMut -> "index_mut")
+ | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
+ let decl =
+ FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
+ in
+ match Collections.List.last decl.name with
+ | PeIdent (s, _) -> s
+ | PeImpl _ ->
+ (* We shouldn't get there *)
+ raise (Failure "Unexpected"))
+ in
+ name ^ "_back"
+ in
+ let ctx, back_vars =
+ fresh_vars
+ (List.map (fun ty -> (Some back_fun_name, ty)) back_tys)
+ ctx
+ in
+ let back_funs =
+ List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
+ in
+ let back_funs_map =
+ RegionGroupId.Map.of_list
+ (List.combine gids (List.map mk_texpression_from_var back_vars))
+ in
+ (ctx, Some back_funs_map, back_funs)
+ else (ctx, None, [])
+ in
(* Register the function call *)
- let ctx = bs_ctx_register_forward_call call_id call args ctx in
- (ctx, func, effect_info, args, out_state)
+ let ctx =
+ bs_ctx_register_forward_call call_id call args back_funs_map ctx
+ in
+ (ctx, func, effect_info, args, back_funs, out_state)
| S.Unop E.Not ->
let effect_info =
{
@@ -1811,7 +1960,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop Not, effect_info, args, None)
+ (ctx, Unop Not, effect_info, args, [], None)
| S.Unop E.Neg -> (
match args with
| [ arg ] ->
@@ -1827,7 +1976,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Neg int_ty), effect_info, args, None)
+ (ctx, Unop (Neg int_ty), effect_info, args, [], None)
| _ -> raise (Failure "Unreachable"))
| S.Unop (E.Cast cast_kind) -> (
match cast_kind with
@@ -1842,7 +1991,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
+ (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None)
| CastFnPtr _ -> raise (Failure "TODO: function casts"))
| S.Binop binop -> (
match args with
@@ -1862,11 +2011,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Binop (binop, int_ty0), effect_info, args, None)
+ (ctx, Binop (binop, int_ty0), effect_info, args, [], None)
| _ -> raise (Failure "Unreachable"))
in
let dest_v =
let dest = mk_typed_pattern_from_var dest dest_mplace in
+ let dest = mk_simpl_tuple_pattern (dest :: back_funs) in
match out_state with
| None -> dest
| Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
@@ -2026,9 +2176,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
else ([], ctx, None)
in
(* Concatenate all the inpus *)
- let inputs =
- List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ]
+ let inherited_inputs =
+ if !Config.return_back_funs then []
+ else List.concat [ fwd_inputs; back_ancestors_inputs ]
in
+ let back_inputs = List.append back_inputs back_state in
(* Retrieve the values given back by this function: those are the output
* values. We rely on the fact that there are no nested borrows to use the
* meta-place information from the input values given to the forward function
@@ -2046,43 +2198,43 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
(* Retrieve the function id, and register the function call in the context
- * if necessary *)
+ if necessary.Arith_status *)
let ctx, func =
- bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
+ bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs
+ back_inputs generics output.ty ctx
in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
+ let inputs = List.append inherited_inputs back_inputs in
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
List.map
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output.ty else output.ty
- in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp func; generics } in
- let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* **Optimization**:
- * =================
- * We do a small optimization here: if the backward function doesn't
- * have any output, we don't introduce any function call.
- * See the comment in {!Config.filter_useless_monadic_calls}.
- *
- * TODO: use an option to disallow backward functions from updating the state.
- * TODO: a backward function which only gives back shared borrows shouldn't
- * update the state (state updates should only be used for mutable borrows,
- * with objects like Rc for instance).
- *)
- if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then (
+ =================
+ We do a small optimization here if we split the forward/backward functions.
+ If the backward function doesn't have any output, we don't introduce any function
+ call.
+ See the comment in {!Config.filter_useless_monadic_calls}.
+
+ TODO: use an option to disallow backward functions from updating the state.
+ TODO: a backward function which only gives back shared borrows shouldn't
+ update the state (state updates should only be used for mutable borrows,
+ with objects like Rc for instance).
+ *)
+ if
+ (not !Config.return_back_funs)
+ && !Config.filter_useless_monadic_calls
+ && outputs = [] && nstate = None
+ then (
(* No outputs - we do a small sanity check: the backward function
- * should have exactly the same number of inputs as the forward:
- * this number can be different only if the forward function returned
- * a value containing mutable borrows, which can't be the case... *)
+ should have exactly the same number of inputs as the forward:
+ this number can be different only if the forward function returned
+ a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else mk_let effect_info.can_fail output call next_e
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index efcf001a..4ec7524b 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -2,6 +2,7 @@ open Types
open TypesUtils
open Expressions
open Values
+open LlbcAst
open SymbolicAst
let mk_mplace (p : place) (ctx : Contexts.eval_ctx) : mplace =
@@ -92,6 +93,7 @@ let synthesize_symbolic_expansion_no_branching (sv : symbolic_value)
synthesize_symbolic_expansion sv place [ Some see ] el
let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
+ (sg : fun_sig option) (regions_hierarchy : region_var_groups)
(abstractions : AbstractionId.id list) (generics : generic_args)
(args : typed_value list) (args_places : mplace option list)
(dest : symbolic_value) (dest_place : mplace option) (e : expression option)
@@ -102,6 +104,8 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
{
call_id;
ctx;
+ sg;
+ regions_hierarchy;
abstractions;
generics;
args;
@@ -118,28 +122,30 @@ let synthesize_global_eval (gid : GlobalDeclId.id) (dest : symbolic_value)
Option.map (fun e -> EvalGlobal (gid, dest, e)) e
let synthesize_regular_function_call (fun_id : fun_id_or_trait_method_ref)
- (call_id : FunCallId.id) (ctx : Contexts.eval_ctx)
+ (call_id : FunCallId.id) (ctx : Contexts.eval_ctx) (sg : fun_sig)
+ (regions_hierarchy : region_var_groups)
(abstractions : AbstractionId.id list) (generics : generic_args)
(args : typed_value list) (args_places : mplace option list)
(dest : symbolic_value) (dest_place : mplace option) (e : expression option)
: expression option =
synthesize_function_call
(Fun (fun_id, call_id))
- ctx abstractions generics args args_places dest dest_place e
+ ctx (Some sg) regions_hierarchy abstractions generics args args_places dest
+ dest_place e
let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : unop)
(arg : typed_value) (arg_place : mplace option) (dest : symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
let generics = empty_generic_args in
- synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ]
- dest dest_place e
+ synthesize_function_call (Unop unop) ctx None [] [] generics [ arg ]
+ [ arg_place ] dest dest_place e
let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : binop)
(arg0 : typed_value) (arg0_place : mplace option) (arg1 : typed_value)
(arg1_place : mplace option) (dest : symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
let generics = empty_generic_args in
- synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ]
+ synthesize_function_call (Binop binop) ctx None [] [] generics [ arg0; arg1 ]
[ arg0_place; arg1_place ] dest dest_place e
let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : abs)
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 631a5af9..5584fb9a 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -129,7 +129,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
in
let sg =
- SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx (FRegular def_id)
+ SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx def_id
fdef.signature input_names
in
@@ -151,6 +151,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
let ctx =
{
+ decls_ctx = trans_ctx;
SymbolicToPure.bid = None;
sg;
(* Will need to be updated for the backward functions *)