summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-21 16:35:27 +0100
committerSon Ho2023-12-21 16:35:27 +0100
commitcf3eea59ee61f2341daf7248664b8be878f128af (patch)
treed2b08318a6dcbcc82773d85b130a4536f3c61e37
parentd9f91cfcd538525f024c6019d7c8250dda8d76fd (diff)
Update SymbolicToPure.ml for the loops
Diffstat (limited to '')
-rw-r--r--compiler/SymbolicToPure.ml221
1 files changed, 125 insertions, 96 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index ef0a0bde..d3b0933c 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -125,6 +125,11 @@ type loop_info = {
(** The map from region group ids to the types of the values given back
by the corresponding loop abstractions.
*)
+ back_funs : texpression RegionGroupId.Map.t option;
+ (** Same as {!call_info.back_funs}.
+ Initialized with [None], gets updated to [Some] only if we merge
+ the fwd/back functions.
+ *)
}
[@@deriving show]
@@ -1123,45 +1128,25 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
in
if effect_info.can_fail then mk_result_ty output else output
-(** Compute the arrow types for all the backward functions.
-
- TODO: merge with below?
- *)
-let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list =
+(** Compute the arrow types for all the backward functions. *)
+let compute_back_tys (dsg : Pure.decomposed_fun_sig)
+ (subst : (generic_args * trait_instance_id) option) : ty list =
List.map
(fun (back_sg : back_sg_info) ->
let effect_info = back_sg.effect_info in
+ (* Compute *)
let inputs = List.map snd back_sg.inputs in
let output = mk_simpl_tuple_ty back_sg.outputs in
let output = mk_output_ty_from_effect_info effect_info output in
- mk_arrows inputs output)
+ let ty = mk_arrows inputs output in
+ (* Substitute - TODO: normalize *)
+ match subst with
+ | None -> ty
+ | Some (generics, tr_self) ->
+ let subst = make_subst_from_generics dsg.generics generics tr_self in
+ ty_substitute subst ty)
(RegionGroupId.Map.values dsg.back_sg)
-(** Return the instantiated pure signature of a backward function, in the
- case the forward/backward functions are merged (i.e., the forward functions
- return the backward functions).
- *)
-let translate_ret_back_inst_fun_sig_from_decomposed
- (dsg : Pure.decomposed_fun_sig) (generics : generic_args)
- (gid : RegionGroupId.id) : inst_fun_sig =
- assert !Config.return_back_funs;
- let mk_output_ty = mk_output_ty_from_effect_info in
- (* Lookup the signature information *)
- let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
- let effect_info = back_sg.effect_info in
- (* Do not prepend the forward inputs *)
- let inputs = List.map snd back_sg.inputs in
- let output = mk_simpl_tuple_ty back_sg.outputs in
- let output = mk_output_ty effect_info output in
- (* Substitute the types *)
- let tr_self = UnknownTrait __FUNCTION__ in
- let subst = make_subst_from_generics dsg.generics generics tr_self in
- let subst = ty_substitute subst in
- let inputs = List.map subst inputs in
- let output = subst output in
- (* Return *)
- { inputs; output }
-
let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid : RegionGroupId.id option) : fun_sig =
let generics = dsg.generics in
@@ -1184,7 +1169,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
if !Config.return_back_funs then (
assert (gid = None);
(* Compute the arrow types for all the backward functions *)
- let back_tys = compute_back_tys dsg in
+ let back_tys = compute_back_tys dsg None in
(* Group the forward output and the types of the backward functions *)
let effect_info = dsg.fwd_info.effect_info in
let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in
@@ -1274,6 +1259,40 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) :
bs_ctx * var list =
List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars
+(* Introduce variables for the backward functions *)
+let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list =
+ (* We lookup the LLBC definition in an attempt to derive pretty names
+ for the backward functions. *)
+ let back_var_names =
+ let def_id = ctx.fun_decl.def_id in
+ let sg = ctx.fun_decl.signature in
+ let regions_hierarchy =
+ LlbcAstUtils.FunIdMap.find (FRegular def_id)
+ ctx.fun_ctx.regions_hierarchies
+ in
+ List.map
+ (fun (gid, _) ->
+ let rg = RegionGroupId.nth regions_hierarchy gid in
+ let region_names =
+ List.map
+ (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
+ rg.regions
+ in
+ let name =
+ match region_names with
+ | [] -> "back"
+ | [ Some r ] -> "back" ^ r
+ | _ ->
+ (* Concatenate all the region names *)
+ "back"
+ ^ String.concat "" (List.filter_map (fun x -> x) region_names)
+ in
+ Some name)
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
+ let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in
+ fresh_vars back_vars ctx
+
let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with
| Some v -> v
@@ -1728,7 +1747,7 @@ and translate_panic (ctx : bs_ctx) : texpression =
match ctx.bid with
| None ->
if !Config.return_back_funs then
- let back_tys = compute_back_tys ctx.sg in
+ let back_tys = compute_back_tys ctx.sg None in
let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in
mk_output output
else mk_output ctx.sg.fwd_output
@@ -1883,22 +1902,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
fid call.regions_hierarchy sg
(List.map (fun _ -> None) sg.inputs)
in
- let gids =
- List.map
- (fun (g : T.region_var_group) -> g.id)
- call.regions_hierarchy
- in
- let back_sgs =
- List.map
- (translate_ret_back_inst_fun_sig_from_decomposed dsg generics)
- gids
- in
+ let tr_self = UnknownTrait __FUNCTION__ in
+ let back_tys = compute_back_tys dsg (Some (generics, tr_self)) in
(* Introduce variables for the backward functions *)
- let back_tys =
- List.map
- (fun (sg : inst_fun_sig) -> mk_arrows sg.inputs sg.output)
- back_sgs
- in
(* Compute a proper basename for the variables *)
let back_fun_name =
let name =
@@ -1934,6 +1940,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let back_funs =
List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
in
+ let gids =
+ List.map
+ (fun (g : T.region_var_group) -> g.id)
+ call.regions_hierarchy
+ in
let back_funs_map =
RegionGroupId.Map.of_list
(List.combine gids (List.map mk_texpression_from_var back_vars))
@@ -2338,6 +2349,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
(* Actually the same case as [SynthInput] *)
translate_end_abstraction_synth_input ectx abs e ctx rg_id
| V.LoopCall ->
+ (* We need to introduce a call to the backward function corresponding
+ to a forward call which happened earlier *)
let fun_id = E.FRegular ctx.fun_decl.def_id in
let effect_info =
get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id)
@@ -2367,7 +2380,10 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
else ([], ctx, None)
in
(* Concatenate all the inputs *)
- let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in
+ let inputs =
+ if !Config.return_back_funs then List.concat [ back_inputs; back_state ]
+ else List.concat [ fwd_inputs; back_inputs; back_state ]
+ in
(* Retrieve the values given back by this function *)
let ctx, outputs = abs_to_given_back None abs ctx in
(* Group the output values together: first the updated inputs *)
@@ -2391,28 +2407,43 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
let ret_ty =
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
- let func = { id = FunOrOp func; generics } in
- let func = { e = Qualif func; ty = func_ty } in
+ (* Create the expression for the function:
+ - it is either a call to a top-level function, if we split the
+ forward/backward functions
+ - or a call to the variable we introduced for the backward function,
+ if we merge the forward/backward functions *)
+ let func =
+ if !Config.return_back_funs then
+ RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
+ else
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
+ let func = { id = FunOrOp func; generics } in
+ { e = Qualif func; ty = func_ty }
+ in
let call = mk_apps func args in
(* **Optimization**:
- * =================
- * We do a small optimization here: if the backward function doesn't
- * have any output, we don't introduce any function call.
- * See the comment in {!Config.filter_useless_monadic_calls}.
- *
- * TODO: use an option to disallow backward functions from updating the state.
- * TODO: a backward function which only gives back shared borrows shouldn't
- * update the state (state updates should only be used for mutable borrows,
- * with objects like Rc for instance).
- *)
- if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None
+ =================
+ We do a small optimization here in case we split the forward/backward
+ functions.
+ If the backward function doesn't have any output, we don't introduce
+ any function call.
+ See the comment in {!Config.filter_useless_monadic_calls}.
+
+ TODO: use an option to disallow backward functions from updating the state.
+ TODO: a backward function which only gives back shared borrows shouldn't
+ update the state (state updates should only be used for mutable borrows,
+ with objects like Rc for instance).
+ *)
+ if
+ (not !Config.return_back_funs)
+ && !Config.filter_useless_monadic_calls
+ && outputs = [] && nstate = None
then (
(* No outputs - we do a small sanity check: the backward function
- * should have exactly the same number of inputs as the forward:
- * this number can be different only if the forward function returned
- * a value containing mutable borrows, which can't be the case... *)
+ should have exactly the same number of inputs as the forward:
+ this number can be different only if the forward function returned
+ a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else
@@ -2860,35 +2891,7 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Introduce variables for the backward functions.
We lookup the LLBC definition in an attempt to derive pretty names
for those functions. *)
- let back_var_names =
- let def_id = ctx.fun_decl.def_id in
- let sg = ctx.fun_decl.signature in
- let regions_hierarchy =
- LlbcAstUtils.FunIdMap.find (FRegular def_id)
- ctx.fun_ctx.regions_hierarchies
- in
- List.map
- (fun (gid, _) ->
- let rg = RegionGroupId.nth regions_hierarchy gid in
- let region_names =
- List.map
- (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
- rg.regions
- in
- let name =
- match region_names with
- | [] -> "back"
- | [ Some r ] -> "back" ^ r
- | _ ->
- (* Concatenate all the region names *)
- "back"
- ^ String.concat "" (List.filter_map (fun x -> x) region_names)
- in
- Some name)
- (RegionGroupId.Map.bindings ctx.sg.back_sg)
- in
- let back_vars = List.combine back_var_names (compute_back_tys ctx.sg) in
- let _, back_vars = fresh_vars back_vars ctx in
+ let _, back_vars = fresh_back_vars_for_current_fun ctx in
(* Create the return expressions *)
let vars = fwd_var :: back_vars in
@@ -2964,8 +2967,32 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Introduce a fresh output value for the forward function *)
let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in
+ (* Introduce fresh variables for the backward functions of the loop.
+
+ For now, the backward functions of the loop are the same as the
+ backward functions of the outer function.
+ *)
+ let ctx, back_funs_map, back_funs =
+ if !Config.return_back_funs then
+ let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
+ let back_funs =
+ List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
+ in
+ let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
+ let back_funs_map =
+ RegionGroupId.Map.of_list
+ (List.combine gids (List.map mk_texpression_from_var back_vars))
+ in
+ (ctx, Some back_funs_map, back_funs)
+ else (ctx, None, [])
+ in
+
+ (* Introduce patterns *)
let args, ctx, out_pats =
+ (* Create the pattern for the output value *)
let output_pat = mk_typed_pattern_from_var output_var None in
+ (* Add the returned backward functions (they might be empty) *)
+ let output_pat = mk_simpl_tuple_pattern (output_pat :: back_funs) in
(* Depending on the function effects:
* - add the fuel
@@ -2988,6 +3015,7 @@ and translate_forward_end (ectx : C.eval_ctx)
loop_info with
forward_inputs = Some args;
forward_output_no_state_no_result = Some output_var;
+ back_funs = back_funs_map;
}
in
let ctx =
@@ -3143,6 +3171,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
forward_inputs = None;
forward_output_no_state_no_result = None;
back_outputs = rg_to_given_back_tys;
+ back_funs = None;
}
in
let loops = LoopId.Map.add loop_id loop_info ctx.loops in