From ea583d9f0f5e4a1a687b70f0e04e875969462157 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 17:20:30 +0100 Subject: Make good progress on updating SymbolicToPure --- compiler/PrintPure.ml | 8 ++ compiler/Pure.ml | 7 +- compiler/PureTypeCheck.ml | 6 ++ compiler/PureUtils.ml | 23 +++++ compiler/SymbolicToPure.ml | 224 +++++++++++++++++++++++++++++++++++++-------- 5 files changed, 226 insertions(+), 42 deletions(-) diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 2fe5843e..3a5ce513 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -592,6 +592,14 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string) in "[ " ^ String.concat ", " fields ^ " ]" | _ -> raise (Failure "Unexpected")) + | Lambda _ -> + let pats, e = destruct_lambdas e in + let vars = + String.concat " " (List.map (typed_pattern_to_string env) pats) + in + let e = texpression_to_string env false indent indent_incr e in + let s = "λ " ^ vars ^ " => " ^ e in + if inside then "(" ^ s ^ ")" else s | Meta (meta, e) -> ( let meta_s = emeta_to_string env meta in let e = texpression_to_string env inside indent indent_incr e in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index fb0509f4..eb6b00c8 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -728,6 +728,7 @@ type expression = | Switch of texpression * switch_body | Loop of loop (** See the comments for {!loop} *) | StructUpdate of struct_update (** See the comments for {!struct_update} *) + | Lambda of typed_pattern * texpression (** [λ x => e] *) | Meta of (emeta[@opaque]) * texpression (** Meta-information *) and switch_body = If of texpression * texpression | Match of match_branch list @@ -912,9 +913,9 @@ type fun_sig_info = { [@@deriving show] type back_sg_info = { - inputs : ty list; (** The additional inputs of the backward function *) - input_names : string option list; - (** The optional names for the additional inputs *) + inputs : (string option * ty) list; + (** The additional inputs of the backward function *) + inputs_no_state : (string option * ty) list; outputs : ty list; (** The "decomposed" list of outputs. diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index a62a2361..3c1800a8 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -229,6 +229,12 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx fe) supd.updates | _ -> raise (Failure "Unexpected")) + | Lambda (pat, e_next) -> + assert (e.ty = e_next.ty); + (* Check the pattern and register the introduced variables at the same time *) + let ctx = check_typed_pattern ctx pat in + (* Check the next expression *) + check_texpression ctx e_next | Meta (_, e_next) -> assert (e_next.ty = e.ty); check_texpression ctx e_next diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index dfea255a..80b25641 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -221,6 +221,9 @@ let rec let_group_requires_parentheses (e : texpression) : bool = if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false | Meta (_, next_e) -> let_group_requires_parentheses next_e + | Lambda (_, _) -> + (* Being conservative here *) + true | Loop _ -> (* Should have been eliminated *) raise (Failure "Unreachable") @@ -713,3 +716,23 @@ let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) let info = TypeDeclId.Map.find id ctx in info.is_tuple_struct | TAssumed _ -> false + +let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) : + texpression = + let ty = TArrow (var.ty, e.ty) in + let pat = PatVar (var, mp) in + let pat = { value = pat; ty = var.ty } in + let e = Lambda (pat, e) in + { e; ty } + +let mk_lambdas_from_vars (vars : var list) (mps : mplace option list) + (e : texpression) : texpression = + let vars = List.combine vars mps in + List.fold_left (fun e (v, mp) -> mk_lambda_from_var v mp e) e vars + +let rec destruct_lambdas (e : texpression) : typed_pattern list * texpression = + match e.e with + | Lambda (pat, e) -> + let pats, e = destruct_lambdas e in + (pat :: pats, e) + | _ -> ([], e) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index d62cc829..8e06db7c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -121,9 +121,9 @@ type loop_info = { (** Body synthesis context *) type bs_ctx = { - type_context : type_context; - fun_context : fun_context; - global_context : global_context; + type_context : type_context; (* TODO: rename *) + fun_context : fun_context; (* TODO: rename *) + global_context : global_context; (* TODO: rename *) trait_decls_ctx : trait_decls_context; trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; @@ -148,7 +148,9 @@ type bs_ctx = { state_var : VarId.id; (** The current state variable, in case the function is stateful *) back_state_vars : VarId.id RegionGroupId.Map.t; - (** The additional input state variable received by a stateful backward function. + (** The additional input state variable received by a stateful backward function, + **in case we are splitting the forward/backward functions**. + When generating stateful functions, we generate code of the following form: @@ -161,7 +163,9 @@ type bs_ctx = { When translating a backward function, we need at some point to update [state_var] with [back_state_var], to account for the fact that the state may have been updated by the caller between the call to the - forward function and the call to the backward function. + forward function and the call to the backward function. We also need + to make sure we use the same variable in all the branches (because + this variable is quantified at the definition level). *) fuel0 : VarId.id; (** The original fuel taken as input by the function (if we use fuel) *) @@ -171,10 +175,20 @@ type bs_ctx = { (** The input parameters for the forward function corresponding to the translated Rust inputs (no fuel, no state). *) - backward_inputs : var list RegionGroupId.Map.t; + backward_inputs_no_state : var list RegionGroupId.Map.t; (** The additional input parameters for the backward functions coming from the borrows consumed upon ending the lifetime (as a consequence those don't include the backward state, if there is one). + + If we split the forward/backward functions: we initialize this map + when initializing the bs_ctx, because those variables are quantified + at the definition level. Otherwise, we initialize it upon diving + into the expressions which are specific to the backward functions. + *) + backward_inputs_with_state : var list RegionGroupId.Map.t; + (** All the additional input parameters for the backward functions. + + Same remarks as for {!backward_inputs_no_state}. *) backward_outputs : var list RegionGroupId.Map.t; (** The variables that the backward functions will output, corresponding @@ -308,13 +322,17 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string = let env = bs_ctx_to_pure_fmt_env ctx in PrintPure.typed_pattern_to_string env p -let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = - match ctx.bid with +let ctx_get_effect_info_for_bid (ctx : bs_ctx) (bid : RegionGroupId.id option) : + fun_effect_info = + match bid with | None -> ctx.sg.fwd_info.effect_info | Some bid -> let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in back_sg.effect_info +let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = + ctx_get_effect_info_for_bid ctx ctx.bid + (* TODO: move *) let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let env = bs_ctx_to_fmt_env ctx in @@ -1009,19 +1027,18 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in - let inputs_no_state_names = - List.map (fun _ -> Some "ret") inputs_no_state + let inputs_no_state = + List.map (fun ty -> (Some "ret", ty)) inputs_no_state in - let state_ty, state_name = - if back_effect_info.stateful then ([ mk_state_ty ], [ None ]) else ([], []) + let state = + if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] in - let inputs = inputs_no_state @ state_ty in - let input_names = inputs_no_state_names @ state_name in + let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in let info = { inputs; - input_names; + inputs_no_state; outputs; output_names; effect_info = back_effect_info; @@ -1061,7 +1078,7 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ back_sg.inputs in + let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty_from_effect_info effect_info output in mk_arrows inputs output) @@ -1105,14 +1122,14 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) | Some gid -> let back_sg = RegionGroupId.Map.find gid dsg.back_sg in let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ back_sg.inputs in + let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty effect_info output in (inputs, output) in { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } -let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = +let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = (* Generate the fresh variable *) let id, var_counter = VarId.fresh ctx.var_counter in let state_var = @@ -1122,7 +1139,7 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = (* Update the context *) let ctx = { ctx with var_counter; state_var = id } in (* Return *) - (ctx, state_pat) + (ctx, state_var, state_pat) (** WARNING: do not call this function directly. Call [fresh_named_var_for_symbolic_value] instead. *) @@ -1776,7 +1793,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in - let ctx, nstate_var = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in @@ -2010,7 +2027,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) let back_state, ctx, nstate = if effect_info.stateful then let back_state = mk_state_texpression ctx.state_var in - let ctx, nstate = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in ([ back_state ], ctx, Some nstate) else ([], ctx, None) in @@ -2115,15 +2132,15 @@ and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs) let-binding: {[ let id_back x nx = - let s = nx in // the name [s] is not important (only collision matters) - ... + let s = nx in // the name [s] is not important (only collision matters) + ... ]} This let-binding later gets inlined, during a micro-pass. *) (* First, retrieve the list of variables used for the inputs for the * backward function *) - let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in + let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in (* Retrieve the values consumed upon ending the loans inside this * abstraction: as there are no nested borrows, there should be none. *) let consumed = abs_to_consumed ctx ectx abs in @@ -2185,7 +2202,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) values consumed upon ending the abstraction (i.e., we don't use [abs_to_consumed]) *) let back_inputs_vars = - T.RegionGroupId.Map.find rg_id ctx.backward_inputs + T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in let back_inputs = List.map mk_texpression_from_var back_inputs_vars in (* If the function is stateful: @@ -2195,7 +2212,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let back_state, ctx, nstate = if effect_info.stateful then let back_state = mk_state_texpression ctx.state_var in - let ctx, nstate = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in ([ back_state ], ctx, Some nstate) else ([], ctx, None) in @@ -2590,25 +2607,69 @@ and translate_forward_end (ectx : C.eval_ctx) let translate_one_end ctx (bid : RegionGroupId.id option) = (* Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression *) - let ctx, e = + let ctx, e, finish = match ctx.bid with | None -> (* We are translating the forward function - nothing to do *) - (ctx, fwd_e) + (ctx, fwd_e, fun e -> e) | Some bid -> (* There are two cases here: - if we split the fwd/backward functions, we simply need to update - the state + the state. - if we don't split, we also need to wrap the expression in a lambda, which introduces the additional inputs of the backward function *) - let back_state_var = RegionGroupId.Map.find bid ctx.back_state_vars in - let ctx = { ctx with state_var = back_state_var } in + let ctx = + (* Introduce variables for the inputs and the state variable + and update the context. *) + if !Config.return_back_funs then + (* If the forward/backward functions are not split, we need + to introduce fresh variables for the additional inputs, + because they are locally introduced in a lambda *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let ctx = { ctx with bid = Some bid } in + let ctx, backward_inputs_no_state = + fresh_vars back_sg.inputs_no_state ctx + in + let ctx, backward_inputs_with_state = + if (ctx_get_effect_info ctx).stateful then + let ctx, var, _ = bs_ctx_fresh_state_var ctx in + (ctx, backward_inputs_no_state @ [ var ]) + else (ctx, backward_inputs_no_state) + in + { + ctx with + backward_inputs_no_state = + RegionGroupId.Map.add bid backward_inputs_no_state + ctx.backward_inputs_no_state; + backward_inputs_with_state = + RegionGroupId.Map.add bid backward_inputs_with_state + ctx.backward_inputs_with_state; + } + else + (* Update the state variable *) + let back_state_var = + RegionGroupId.Map.find bid ctx.back_state_vars + in + { ctx with state_var = back_state_var } + in + let e = T.RegionGroupId.Map.find bid back_e in - (ctx, e) + let finish e = + (* Wrap in lambdas if necessary *) + if !Config.return_back_funs then + let inputs = + RegionGroupId.Map.find bid ctx.backward_inputs_with_state + in + let places = List.map (fun _ -> None) inputs in + mk_lambdas_from_vars inputs places e + else e + in + (ctx, e, finish) in - translate_expression e ctx + let e = translate_expression e ctx in + finish e in (* There are two cases, depending on whether we are splitting the forward/backward @@ -2624,7 +2685,87 @@ and translate_forward_end (ectx : C.eval_ctx) Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression. *) - let translate_end ctx = failwith "TODO" in + let translate_end ctx = + if !Config.return_back_funs then + (* Compute the output of the forward function *) + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output_ty = + let ty = ctx.sg.fwd_output in + if fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] + else ty + in + let ctx, fwd_var = fresh_var None output_ty ctx in + let ctx, state_var, state_pat = + if fwd_effect_info.stateful then + let ctx, var, pat = bs_ctx_fresh_state_var ctx in + (ctx, [ var ], [ pat ]) + else (ctx, [], []) + in + let fwd_e = translate_one_end ctx None in + + (* Introduce the backward functions *) + let back_el = + List.map + (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> + translate_one_end ctx (Some gid)) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + (* Introduce variables for the backward functions. + We lookup the LLBC definition in an attempt to derive pretty names + for those functions. *) + let back_var_names = + let def_id = ctx.fun_decl.def_id in + let sg = ctx.fun_decl.signature in + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) + ctx.fun_context.regions_hierarchies + in + List.map + (fun (gid, _) -> + let rg = RegionGroupId.nth regions_hierarchy gid in + let region_names = + List.map + (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) + rg.regions + in + let name = + match region_names with + | [] -> "back" + | [ Some r ] -> "back" ^ r + | _ -> + (* Concatenate all the region names *) + "back" + ^ String.concat "" (List.filter_map (fun x -> x) region_names) + in + Some name) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + let back_vars = List.combine back_var_names (compute_back_tys ctx.sg) in + let _, back_vars = fresh_vars back_vars ctx in + + (* Create the return expressions *) + let vars = fwd_var :: back_vars in + let vars = List.map mk_texpression_from_var vars in + let ret = mk_simpl_tuple_texpression vars in + let state_var = List.map mk_texpression_from_var state_var in + let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in + let ret = mk_result_return_texpression ret in + + (* Bind the expressions for the backward function and the expression + for the computation of the forward output *) + let e = + List.fold_right + (fun (var, back_e) e -> + mk_let false (mk_typed_pattern_from_var var None) back_e e) + (List.combine back_vars back_el) + ret + in + (* Bind the expression for the forward output *) + let fwd_var = mk_typed_pattern_from_var fwd_var None in + let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in + mk_let fwd_effect_info.can_fail pat fwd_e e + else translate_one_end ctx ctx.bid + in (* If we are (re-)entering a loop, we need to introduce a call to the forward translation of the loop. *) @@ -2687,7 +2828,7 @@ and translate_forward_end (ectx : C.eval_ctx) let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in - let ctx, nstate_pat = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate_pat = bs_ctx_fresh_state_var ctx in ( List.concat [ fuel; args; [ state_var ] ], ctx, [ nstate_pat; output_pat ] ) @@ -3025,8 +3166,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let def_id = def.def_id in let llbc_name = def.name in let name = name_to_string ctx llbc_name in - (* Retrieve the signature *) - let signature = ctx.sg in + (* Translate the signature *) + let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in let regions_hierarchy = FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies in @@ -3070,20 +3211,25 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = match bid with | None -> [] | Some back_id -> + assert (not !Config.return_back_funs); let parents_ids = list_ordered_ancestor_region_groups regions_hierarchy back_id in let backward_ids = List.append parents_ids [ back_id ] in List.concat (List.map - (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs) + (fun id -> + T.RegionGroupId.Map.find id ctx.backward_inputs_no_state) backward_ids) in (* Introduce the backward input state (the state at call site of the * *backward* function), if necessary *) let back_state = if effect_info.stateful && Option.is_some bid then - [ mk_state_var ctx.back_state_var ] + let state_var = + RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars + in + [ mk_state_var state_var ] else [] in (* Group the inputs together *) -- cgit v1.2.3