From 62cb926e76ef0c9fb048b0e340bdae5b9dd76a84 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 14:06:16 +0100 Subject: Make progress on updating SymbolicToPure --- compiler/SymbolicToPure.ml | 169 ++++++++++++++++++++++++++++++--------------- 1 file changed, 112 insertions(+), 57 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 456ec0f6..d62cc829 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -127,7 +127,15 @@ type bs_ctx = { trait_decls_ctx : trait_decls_context; trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; - bid : T.RegionGroupId.id option; (** TODO: rename *) + bid : RegionGroupId.id option; + (** TODO: rename + + The id of the group region we are currently translating. + If we split the forward/backward functions, we set this id at the + very beginning of the translation. + If we don't split, we set it to `None`, then update it when we enter + an expression which is specific to a backward function. + *) sg : decomposed_fun_sig; (** Information about the function signature - useful in particular to translate [Panic] *) @@ -139,7 +147,7 @@ type bs_ctx = { var_counter : VarId.generator; state_var : VarId.id; (** The current state variable, in case the function is stateful *) - back_state_var : VarId.id; + back_state_vars : VarId.id RegionGroupId.Map.t; (** The additional input state variable received by a stateful backward function. When generating stateful functions, we generate code of the following form: @@ -163,16 +171,16 @@ type bs_ctx = { (** The input parameters for the forward function corresponding to the translated Rust inputs (no fuel, no state). *) - backward_inputs : var list T.RegionGroupId.Map.t; + backward_inputs : var list RegionGroupId.Map.t; (** The additional input parameters for the backward functions coming from the borrows consumed upon ending the lifetime (as a consequence those don't include the backward state, if there is one). *) - backward_outputs : var list T.RegionGroupId.Map.t; + backward_outputs : var list RegionGroupId.Map.t; (** The variables that the backward functions will output, corresponding to the borrows they give back (don't include the backward state) *) - loop_backward_outputs : var list T.RegionGroupId.Map.t option; + loop_backward_outputs : var list RegionGroupId.Map.t option; (** Same as {!backward_outputs}, but for loops (if we entered a loop). [None] if we are not inside a loop, [Some] otherwise (and whatever @@ -300,6 +308,13 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string = let env = bs_ctx_to_pure_fmt_env ctx in PrintPure.typed_pattern_to_string env p +let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = + match ctx.bid with + | None -> ctx.sg.fwd_info.effect_info + | Some bid -> + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + back_sg.effect_info + (* TODO: move *) let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let env = bs_ctx_to_fmt_env ctx in @@ -1034,6 +1049,24 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) fwd_info; } +let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty + = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty + in + if effect_info.can_fail then mk_result_ty output else output + +(** Compute the arrow types for all the backward functions *) +let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = + List.map + (fun (back_sg : back_sg_info) -> + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty_from_effect_info effect_info output in + mk_arrows inputs output) + (RegionGroupId.Map.values dsg.back_sg) + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1050,27 +1083,13 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) in (* Two cases depending on whether we split the forward/backward functions or not *) - let mk_output_ty (effect_info : fun_effect_info) output = - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - if effect_info.can_fail then mk_result_ty output else output - in + let mk_output_ty = mk_output_ty_from_effect_info in + let inputs, output = if !Config.return_back_funs then ( assert (gid = None); (* Compute the arrow types for all the backward functions *) - let back_tys = - List.map - (fun (back_sg : back_sg_info) -> - let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty effect_info output in - mk_arrows inputs output) - (RegionGroupId.Map.values dsg.back_sg) - in + let back_tys = compute_back_tys dsg in (* Group the forward output and the types of the backward functions *) let effect_info = dsg.fwd_info.effect_info in let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in @@ -1584,30 +1603,43 @@ and translate_panic (ctx : bs_ctx) : texpression = * but it won't be true anymore once we translate individual blocks *) (* If we use a state monad, we need to add a lambda for the state variable *) (* Note that only forward functions return a state *) - let output_ty = - if ctx.inside_loop && Option.is_some ctx.bid then - (* We are synthesizing the backward function of a loop body *) - let bid = Option.get ctx.bid in - let back_vars = - T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) - in - let tys = List.map (fun (v : var) -> v.ty) back_vars in - mk_simpl_tuple_ty tys - else - (* Regular function, or forward function (the forward translation for - a loop has the same return type as the parent function) - *) - mk_simpl_tuple_ty ctx.sg.doutputs - in + let effect_info = ctx_get_effect_info ctx in (* TODO: we should use a [Fail] function *) - if ctx.sg.info.effect_info.stateful then - (* Create the [Fail] value *) - let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in - let ret_v = - mk_result_fail_texpression_with_error_id error_failure_id ret_ty + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id error_failure_id ret_ty + in + ret_v + else mk_result_fail_texpression_with_error_id error_failure_id output_ty + in + if ctx.inside_loop && Option.is_some ctx.bid then + (* We are synthesizing the backward function of a loop body *) + let bid = Option.get ctx.bid in + let back_vars = + T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) in - ret_v - else mk_result_fail_texpression_with_error_id error_failure_id output_ty + let tys = List.map (fun (v : var) -> v.ty) back_vars in + let output = mk_simpl_tuple_ty tys in + mk_output output + else + (* Regular function, or forward function (the forward translation for + a loop has the same return type as the parent function) + *) + match ctx.bid with + | None -> + if !Config.return_back_funs then + let back_tys = compute_back_tys ctx.sg in + let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in + mk_output output + else mk_output ctx.sg.fwd_output + | Some bid -> + let output = + mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs + in + mk_output output (** [opt_v]: the value to return, in case we translate a forward body *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) @@ -1641,7 +1673,7 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) * - error-monad: Return x * - state-error: Return (state, x) * *) - let effect_info = ctx.sg.info.effect_info in + let effect_info = ctx_get_effect_info ctx in let output = if effect_info.stateful then let state_rvalue = mk_state_texpression ctx.state_var in @@ -1695,7 +1727,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) * effect - in particular, one manipulates a state iff the other does * the same. * *) - let effect_info = ctx.sg.info.effect_info in + let effect_info = ctx_get_effect_info ctx in let output = if effect_info.stateful then let state_rvalue = mk_state_texpression ctx.state_var in @@ -2550,24 +2582,50 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) and translate_forward_end (ectx : C.eval_ctx) (loop_input_values : V.typed_value S.symbolic_value_id_map option) - (e : S.expression) (back_e : S.expression S.region_group_id_map) + (fwd_e : S.expression) (back_e : S.expression S.region_group_id_map) (ctx : bs_ctx) : texpression = - (* Update the current state with the additional state received by the backward - function, if needs be, and lookup the proper expression *) - let translate_end ctx = + (* TODO: *) + assert (not !Config.return_back_funs); + + let translate_one_end ctx (bid : RegionGroupId.id option) = (* Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression *) let ctx, e = match ctx.bid with - | None -> (ctx, e) + | None -> + (* We are translating the forward function - nothing to do *) + (ctx, fwd_e) | Some bid -> - let ctx = { ctx with state_var = ctx.back_state_var } in + (* There are two cases here: + - if we split the fwd/backward functions, we simply need to update + the state + - if we don't split, we also need to wrap the expression in a + lambda, which introduces the additional inputs of the backward + function + *) + let back_state_var = RegionGroupId.Map.find bid ctx.back_state_vars in + let ctx = { ctx with state_var = back_state_var } in let e = T.RegionGroupId.Map.find bid back_e in (ctx, e) in translate_expression e ctx in + (* There are two cases, depending on whether we are splitting the forward/backward + functions or not. + + - if we split, then we simply need to translate the proper "end" expression, + that is the end of the forward function, or of the backward function we + are currently translating. + - if we don't split, then we need to translate the end of the forward + function (this is the value we will return) and generate the bodies + of the backward functions (which we will also return). + + Update the current state with the additional state received by the backward + function, if needs be, and lookup the proper expression. + *) + let translate_end ctx = failwith "TODO" in + (* If we are (re-)entering a loop, we need to introduce a call to the forward translation of the loop. *) match loop_input_values with @@ -2617,10 +2675,7 @@ and translate_forward_end (ectx : C.eval_ctx) in (* Introduce a fresh output value for the forward function *) - let ctx, output_var = - let output_ty = mk_simpl_tuple_ty ctx.fwd_sg.doutputs in - fresh_var None output_ty ctx - in + let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in let args, ctx, out_pats = let output_pat = mk_typed_pattern_from_var output_var None in @@ -2832,7 +2887,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Add the input state *) let input_state = - if ctx.sg.info.effect_info.stateful then Some ctx.state_var else None + if (ctx_get_effect_info ctx).stateful then Some ctx.state_var else None in (* Translate the loop body *) -- cgit v1.2.3