summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-15 17:20:30 +0100
committerSon Ho2023-12-15 17:20:30 +0100
commitea583d9f0f5e4a1a687b70f0e04e875969462157 (patch)
tree6afa76ace8e4771c6757d2ae8b65e2493b6660e6
parent62cb926e76ef0c9fb048b0e340bdae5b9dd76a84 (diff)
Make good progress on updating SymbolicToPure
Diffstat (limited to '')
-rw-r--r--compiler/PrintPure.ml8
-rw-r--r--compiler/Pure.ml7
-rw-r--r--compiler/PureTypeCheck.ml6
-rw-r--r--compiler/PureUtils.ml23
-rw-r--r--compiler/SymbolicToPure.ml224
5 files changed, 226 insertions, 42 deletions
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 2fe5843e..3a5ce513 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -592,6 +592,14 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string)
in
"[ " ^ String.concat ", " fields ^ " ]"
| _ -> raise (Failure "Unexpected"))
+ | Lambda _ ->
+ let pats, e = destruct_lambdas e in
+ let vars =
+ String.concat " " (List.map (typed_pattern_to_string env) pats)
+ in
+ let e = texpression_to_string env false indent indent_incr e in
+ let s = "λ " ^ vars ^ " => " ^ e in
+ if inside then "(" ^ s ^ ")" else s
| Meta (meta, e) -> (
let meta_s = emeta_to_string env meta in
let e = texpression_to_string env inside indent indent_incr e in
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index fb0509f4..eb6b00c8 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -728,6 +728,7 @@ type expression =
| Switch of texpression * switch_body
| Loop of loop (** See the comments for {!loop} *)
| StructUpdate of struct_update (** See the comments for {!struct_update} *)
+ | Lambda of typed_pattern * texpression (** [λ x => e] *)
| Meta of (emeta[@opaque]) * texpression (** Meta-information *)
and switch_body = If of texpression * texpression | Match of match_branch list
@@ -912,9 +913,9 @@ type fun_sig_info = {
[@@deriving show]
type back_sg_info = {
- inputs : ty list; (** The additional inputs of the backward function *)
- input_names : string option list;
- (** The optional names for the additional inputs *)
+ inputs : (string option * ty) list;
+ (** The additional inputs of the backward function *)
+ inputs_no_state : (string option * ty) list;
outputs : ty list;
(** The "decomposed" list of outputs.
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index a62a2361..3c1800a8 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -229,6 +229,12 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
check_texpression ctx fe)
supd.updates
| _ -> raise (Failure "Unexpected"))
+ | Lambda (pat, e_next) ->
+ assert (e.ty = e_next.ty);
+ (* Check the pattern and register the introduced variables at the same time *)
+ let ctx = check_typed_pattern ctx pat in
+ (* Check the next expression *)
+ check_texpression ctx e_next
| Meta (_, e_next) ->
assert (e_next.ty = e.ty);
check_texpression ctx e_next
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index dfea255a..80b25641 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -221,6 +221,9 @@ let rec let_group_requires_parentheses (e : texpression) : bool =
if monadic then true else let_group_requires_parentheses next_e
| Switch (_, _) -> false
| Meta (_, next_e) -> let_group_requires_parentheses next_e
+ | Lambda (_, _) ->
+ (* Being conservative here *)
+ true
| Loop _ ->
(* Should have been eliminated *)
raise (Failure "Unreachable")
@@ -713,3 +716,23 @@ let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos)
let info = TypeDeclId.Map.find id ctx in
info.is_tuple_struct
| TAssumed _ -> false
+
+let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) :
+ texpression =
+ let ty = TArrow (var.ty, e.ty) in
+ let pat = PatVar (var, mp) in
+ let pat = { value = pat; ty = var.ty } in
+ let e = Lambda (pat, e) in
+ { e; ty }
+
+let mk_lambdas_from_vars (vars : var list) (mps : mplace option list)
+ (e : texpression) : texpression =
+ let vars = List.combine vars mps in
+ List.fold_left (fun e (v, mp) -> mk_lambda_from_var v mp e) e vars
+
+let rec destruct_lambdas (e : texpression) : typed_pattern list * texpression =
+ match e.e with
+ | Lambda (pat, e) ->
+ let pats, e = destruct_lambdas e in
+ (pat :: pats, e)
+ | _ -> ([], e)
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index d62cc829..8e06db7c 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -121,9 +121,9 @@ type loop_info = {
(** Body synthesis context *)
type bs_ctx = {
- type_context : type_context;
- fun_context : fun_context;
- global_context : global_context;
+ type_context : type_context; (* TODO: rename *)
+ fun_context : fun_context; (* TODO: rename *)
+ global_context : global_context; (* TODO: rename *)
trait_decls_ctx : trait_decls_context;
trait_impls_ctx : trait_impls_context;
fun_decl : A.fun_decl;
@@ -148,7 +148,9 @@ type bs_ctx = {
state_var : VarId.id;
(** The current state variable, in case the function is stateful *)
back_state_vars : VarId.id RegionGroupId.Map.t;
- (** The additional input state variable received by a stateful backward function.
+ (** The additional input state variable received by a stateful backward function,
+ **in case we are splitting the forward/backward functions**.
+
When generating stateful functions, we generate code of the following
form:
@@ -161,7 +163,9 @@ type bs_ctx = {
When translating a backward function, we need at some point to update
[state_var] with [back_state_var], to account for the fact that the
state may have been updated by the caller between the call to the
- forward function and the call to the backward function.
+ forward function and the call to the backward function. We also need
+ to make sure we use the same variable in all the branches (because
+ this variable is quantified at the definition level).
*)
fuel0 : VarId.id;
(** The original fuel taken as input by the function (if we use fuel) *)
@@ -171,10 +175,20 @@ type bs_ctx = {
(** The input parameters for the forward function corresponding to the
translated Rust inputs (no fuel, no state).
*)
- backward_inputs : var list RegionGroupId.Map.t;
+ backward_inputs_no_state : var list RegionGroupId.Map.t;
(** The additional input parameters for the backward functions coming
from the borrows consumed upon ending the lifetime (as a consequence
those don't include the backward state, if there is one).
+
+ If we split the forward/backward functions: we initialize this map
+ when initializing the bs_ctx, because those variables are quantified
+ at the definition level. Otherwise, we initialize it upon diving
+ into the expressions which are specific to the backward functions.
+ *)
+ backward_inputs_with_state : var list RegionGroupId.Map.t;
+ (** All the additional input parameters for the backward functions.
+
+ Same remarks as for {!backward_inputs_no_state}.
*)
backward_outputs : var list RegionGroupId.Map.t;
(** The variables that the backward functions will output, corresponding
@@ -308,13 +322,17 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string =
let env = bs_ctx_to_pure_fmt_env ctx in
PrintPure.typed_pattern_to_string env p
-let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info =
- match ctx.bid with
+let ctx_get_effect_info_for_bid (ctx : bs_ctx) (bid : RegionGroupId.id option) :
+ fun_effect_info =
+ match bid with
| None -> ctx.sg.fwd_info.effect_info
| Some bid ->
let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
back_sg.effect_info
+let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info =
+ ctx_get_effect_info_for_bid ctx ctx.bid
+
(* TODO: move *)
let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
let env = bs_ctx_to_fmt_env ctx in
@@ -1009,19 +1027,18 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
get_fun_effect_info fun_infos (FunId fun_id) None (Some gid)
in
let inputs_no_state = translate_back_inputs_for_gid gid in
- let inputs_no_state_names =
- List.map (fun _ -> Some "ret") inputs_no_state
+ let inputs_no_state =
+ List.map (fun ty -> (Some "ret", ty)) inputs_no_state
in
- let state_ty, state_name =
- if back_effect_info.stateful then ([ mk_state_ty ], [ None ]) else ([], [])
+ let state =
+ if back_effect_info.stateful then [ (None, mk_state_ty) ] else []
in
- let inputs = inputs_no_state @ state_ty in
- let input_names = inputs_no_state_names @ state_name in
+ let inputs = inputs_no_state @ state in
let output_names, outputs = compute_back_outputs_for_gid gid in
let info =
{
inputs;
- input_names;
+ inputs_no_state;
outputs;
output_names;
effect_info = back_effect_info;
@@ -1061,7 +1078,7 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list =
List.map
(fun (back_sg : back_sg_info) ->
let effect_info = back_sg.effect_info in
- let inputs = dsg.fwd_inputs @ back_sg.inputs in
+ let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in
let output = mk_simpl_tuple_ty back_sg.outputs in
let output = mk_output_ty_from_effect_info effect_info output in
mk_arrows inputs output)
@@ -1105,14 +1122,14 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
| Some gid ->
let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
let effect_info = back_sg.effect_info in
- let inputs = dsg.fwd_inputs @ back_sg.inputs in
+ let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in
let output = mk_simpl_tuple_ty back_sg.outputs in
let output = mk_output_ty effect_info output in
(inputs, output)
in
{ generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info }
-let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
+let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern =
(* Generate the fresh variable *)
let id, var_counter = VarId.fresh ctx.var_counter in
let state_var =
@@ -1122,7 +1139,7 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
(* Update the context *)
let ctx = { ctx with var_counter; state_var = id } in
(* Return *)
- (ctx, state_pat)
+ (ctx, state_var, state_pat)
(** WARNING: do not call this function directly.
Call [fresh_named_var_for_symbolic_value] instead. *)
@@ -1776,7 +1793,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
- let ctx, nstate_var = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
@@ -2010,7 +2027,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
let back_state, ctx, nstate =
if effect_info.stateful then
let back_state = mk_state_texpression ctx.state_var in
- let ctx, nstate = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
@@ -2115,15 +2132,15 @@ and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs)
let-binding:
{[
let id_back x nx =
- let s = nx in // the name [s] is not important (only collision matters)
- ...
+ let s = nx in // the name [s] is not important (only collision matters)
+ ...
]}
This let-binding later gets inlined, during a micro-pass.
*)
(* First, retrieve the list of variables used for the inputs for the
* backward function *)
- let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in
+ let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: as there are no nested borrows, there should be none. *)
let consumed = abs_to_consumed ctx ectx abs in
@@ -2185,7 +2202,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
values consumed upon ending the abstraction (i.e., we don't use
[abs_to_consumed]) *)
let back_inputs_vars =
- T.RegionGroupId.Map.find rg_id ctx.backward_inputs
+ T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state
in
let back_inputs = List.map mk_texpression_from_var back_inputs_vars in
(* If the function is stateful:
@@ -2195,7 +2212,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
let back_state, ctx, nstate =
if effect_info.stateful then
let back_state = mk_state_texpression ctx.state_var in
- let ctx, nstate = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
@@ -2590,25 +2607,69 @@ and translate_forward_end (ectx : C.eval_ctx)
let translate_one_end ctx (bid : RegionGroupId.id option) =
(* Update the current state with the additional state received by the backward
function, if needs be, and lookup the proper expression *)
- let ctx, e =
+ let ctx, e, finish =
match ctx.bid with
| None ->
(* We are translating the forward function - nothing to do *)
- (ctx, fwd_e)
+ (ctx, fwd_e, fun e -> e)
| Some bid ->
(* There are two cases here:
- if we split the fwd/backward functions, we simply need to update
- the state
+ the state.
- if we don't split, we also need to wrap the expression in a
lambda, which introduces the additional inputs of the backward
function
*)
- let back_state_var = RegionGroupId.Map.find bid ctx.back_state_vars in
- let ctx = { ctx with state_var = back_state_var } in
+ let ctx =
+ (* Introduce variables for the inputs and the state variable
+ and update the context. *)
+ if !Config.return_back_funs then
+ (* If the forward/backward functions are not split, we need
+ to introduce fresh variables for the additional inputs,
+ because they are locally introduced in a lambda *)
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ let ctx = { ctx with bid = Some bid } in
+ let ctx, backward_inputs_no_state =
+ fresh_vars back_sg.inputs_no_state ctx
+ in
+ let ctx, backward_inputs_with_state =
+ if (ctx_get_effect_info ctx).stateful then
+ let ctx, var, _ = bs_ctx_fresh_state_var ctx in
+ (ctx, backward_inputs_no_state @ [ var ])
+ else (ctx, backward_inputs_no_state)
+ in
+ {
+ ctx with
+ backward_inputs_no_state =
+ RegionGroupId.Map.add bid backward_inputs_no_state
+ ctx.backward_inputs_no_state;
+ backward_inputs_with_state =
+ RegionGroupId.Map.add bid backward_inputs_with_state
+ ctx.backward_inputs_with_state;
+ }
+ else
+ (* Update the state variable *)
+ let back_state_var =
+ RegionGroupId.Map.find bid ctx.back_state_vars
+ in
+ { ctx with state_var = back_state_var }
+ in
+
let e = T.RegionGroupId.Map.find bid back_e in
- (ctx, e)
+ let finish e =
+ (* Wrap in lambdas if necessary *)
+ if !Config.return_back_funs then
+ let inputs =
+ RegionGroupId.Map.find bid ctx.backward_inputs_with_state
+ in
+ let places = List.map (fun _ -> None) inputs in
+ mk_lambdas_from_vars inputs places e
+ else e
+ in
+ (ctx, e, finish)
in
- translate_expression e ctx
+ let e = translate_expression e ctx in
+ finish e
in
(* There are two cases, depending on whether we are splitting the forward/backward
@@ -2624,7 +2685,87 @@ and translate_forward_end (ectx : C.eval_ctx)
Update the current state with the additional state received by the backward
function, if needs be, and lookup the proper expression.
*)
- let translate_end ctx = failwith "TODO" in
+ let translate_end ctx =
+ if !Config.return_back_funs then
+ (* Compute the output of the forward function *)
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let output_ty =
+ let ty = ctx.sg.fwd_output in
+ if fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ]
+ else ty
+ in
+ let ctx, fwd_var = fresh_var None output_ty ctx in
+ let ctx, state_var, state_pat =
+ if fwd_effect_info.stateful then
+ let ctx, var, pat = bs_ctx_fresh_state_var ctx in
+ (ctx, [ var ], [ pat ])
+ else (ctx, [], [])
+ in
+ let fwd_e = translate_one_end ctx None in
+
+ (* Introduce the backward functions *)
+ let back_el =
+ List.map
+ (fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
+ translate_one_end ctx (Some gid))
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
+ (* Introduce variables for the backward functions.
+ We lookup the LLBC definition in an attempt to derive pretty names
+ for those functions. *)
+ let back_var_names =
+ let def_id = ctx.fun_decl.def_id in
+ let sg = ctx.fun_decl.signature in
+ let regions_hierarchy =
+ LlbcAstUtils.FunIdMap.find (FRegular def_id)
+ ctx.fun_context.regions_hierarchies
+ in
+ List.map
+ (fun (gid, _) ->
+ let rg = RegionGroupId.nth regions_hierarchy gid in
+ let region_names =
+ List.map
+ (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
+ rg.regions
+ in
+ let name =
+ match region_names with
+ | [] -> "back"
+ | [ Some r ] -> "back" ^ r
+ | _ ->
+ (* Concatenate all the region names *)
+ "back"
+ ^ String.concat "" (List.filter_map (fun x -> x) region_names)
+ in
+ Some name)
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
+ let back_vars = List.combine back_var_names (compute_back_tys ctx.sg) in
+ let _, back_vars = fresh_vars back_vars ctx in
+
+ (* Create the return expressions *)
+ let vars = fwd_var :: back_vars in
+ let vars = List.map mk_texpression_from_var vars in
+ let ret = mk_simpl_tuple_texpression vars in
+ let state_var = List.map mk_texpression_from_var state_var in
+ let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
+ let ret = mk_result_return_texpression ret in
+
+ (* Bind the expressions for the backward function and the expression
+ for the computation of the forward output *)
+ let e =
+ List.fold_right
+ (fun (var, back_e) e ->
+ mk_let false (mk_typed_pattern_from_var var None) back_e e)
+ (List.combine back_vars back_el)
+ ret
+ in
+ (* Bind the expression for the forward output *)
+ let fwd_var = mk_typed_pattern_from_var fwd_var None in
+ let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
+ mk_let fwd_effect_info.can_fail pat fwd_e e
+ else translate_one_end ctx ctx.bid
+ in
(* If we are (re-)entering a loop, we need to introduce a call to the
forward translation of the loop. *)
@@ -2687,7 +2828,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
- let ctx, nstate_pat = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate_pat = bs_ctx_fresh_state_var ctx in
( List.concat [ fuel; args; [ state_var ] ],
ctx,
[ nstate_pat; output_pat ] )
@@ -3025,8 +3166,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let def_id = def.def_id in
let llbc_name = def.name in
let name = name_to_string ctx llbc_name in
- (* Retrieve the signature *)
- let signature = ctx.sg in
+ (* Translate the signature *)
+ let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in
let regions_hierarchy =
FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies
in
@@ -3070,20 +3211,25 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
match bid with
| None -> []
| Some back_id ->
+ assert (not !Config.return_back_funs);
let parents_ids =
list_ordered_ancestor_region_groups regions_hierarchy back_id
in
let backward_ids = List.append parents_ids [ back_id ] in
List.concat
(List.map
- (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs)
+ (fun id ->
+ T.RegionGroupId.Map.find id ctx.backward_inputs_no_state)
backward_ids)
in
(* Introduce the backward input state (the state at call site of the
* *backward* function), if necessary *)
let back_state =
if effect_info.stateful && Option.is_some bid then
- [ mk_state_var ctx.back_state_var ]
+ let state_var =
+ RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars
+ in
+ [ mk_state_var state_var ]
else []
in
(* Group the inputs together *)