From bc154dda94c44b3ae67a3b04d3866cc473aead32 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 13:41:57 +0100 Subject: Remove the option to split fwd/back functions and update SymbolicToPure --- compiler/SymbolicToPure.ml | 936 ++++++++++++++++++--------------------------- 1 file changed, 372 insertions(+), 564 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 3a50e495..859d6f17 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -805,11 +805,9 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) that we need to call. This function may be [None] if it has to be ignored (because it does nothing). *) -let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) - (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) - (inherited_args : texpression list) (back_args : texpression list) - (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : - bs_ctx * texpression option = +let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) + (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) + : bs_ctx * texpression option = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -819,29 +817,9 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) let abstractions = V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions in - (* Compute the expression corresponding to the function *) - let func = - if !Config.return_back_funs then - (* Lookup the variable introduced for the backward function *) - RegionGroupId.Map.find back_id (Option.get info.back_funs) - else - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> - let fid = translate_fun_id_or_trait_method_ref ctx fid in - Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") - in - let args = List.append inherited_args back_args in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output_ty else output_ty - in - let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp fun_id; generics } in - Some { e = Qualif func; ty = func_ty } - in + (* Compute the expression corresponding to the function. + We simply lookup the variable introduced for the backward function. *) + let func = RegionGroupId.Map.find back_id (Option.get info.back_funs) in (* Update the context and return *) ({ ctx with calls; abstractions }, func) @@ -1124,20 +1102,34 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs_no_state = List.map (fun ty -> (Some "ret", ty)) inputs_no_state in - (* In case we merge the forward/backward functions: - we consider the backward function as stateful and potentially failing + (* We consider a backward function as stateful and potentially failing **only if it has inputs** (for the "potentially failing": if it has not inputs, we directly evaluate it in the body of the forward function). + + For instance, we do the following: + {[ + // Rust + fn push(v : &mut Vec, x : T) { ... } + + (* Generated code: before doing unit elimination. + We return (), as well as the backward function; as the backward + function doesn't consume any inputs, it is a value that we compute + directly in the body of [push]. + *) + let push T (v : Vec T) (x : T) : Result (() * Vec T) = ... + + (* Generated code: after doing unit elimination, if we simplify the merged + fwd/back functions (see below). *) + let push T (v : Vec T) (x : T) : Result (Vec T) = ... + ]} *) let back_effect_info = - if !Config.return_back_funs then - let b = inputs_no_state <> [] in - { - back_effect_info with - stateful = back_effect_info.stateful && b; - can_fail = back_effect_info.can_fail && b; - } - else back_effect_info + let b = inputs_no_state <> [] in + { + back_effect_info with + stateful = back_effect_info.stateful && b; + can_fail = back_effect_info.can_fail && b; + } in let state = if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] @@ -1145,8 +1137,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in let filter = - !Config.simplify_merged_fwd_backs - && !Config.return_back_funs && inputs = [] && outputs = [] + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] in let info = { @@ -1186,7 +1177,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed } in let ignore_output = - if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then + if !Config.simplify_merged_fwd_backs then ty_is_unit fwd_output && List.exists (fun (info : back_sg_info) -> not info.filter) @@ -1296,10 +1287,10 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) (subst : (generic_args * trait_instance_id) option) : ty option list = List.map (Option.map snd) (compute_back_tys_with_info dsg subst) -(** In case we merge the fwd/back functions: compute the output type of - a function, from a decomposed signature. *) +(** Compute the output type of a function, from a decomposed signature + (the output type contains the type of the value returned by the forward + function as well as the types of the returned backward functions). *) let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = - assert !Config.return_back_funs; (* Compute the arrow types for all the backward functions *) let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in (* Group the forward output and the types of the backward functions *) @@ -1315,8 +1306,8 @@ let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = in mk_output_ty_from_effect_info effect_info output -let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) - (gid : RegionGroupId.id option) : fun_sig = +let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) : fun_sig + = let generics = dsg.generics in let llbc_generics = dsg.llbc_generics in let preds = dsg.preds in @@ -1329,27 +1320,10 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid, info.effect_info)) (RegionGroupId.Map.bindings dsg.back_sg)) in - let mk_output_ty = mk_output_ty_from_effect_info in let inputs, output = - (* Two cases depending on whether we split the forward/backward functions or not *) - if !Config.return_back_funs then ( - assert (gid = None); - let output = compute_output_ty_from_decomposed dsg in - let inputs = dsg.fwd_inputs in - (inputs, output)) - else - match gid with - | None -> - let effect_info = dsg.fwd_info.effect_info in - let output = mk_output_ty effect_info dsg.fwd_output in - (dsg.fwd_inputs, output) - | Some gid -> - let back_sg = RegionGroupId.Map.find gid dsg.back_sg in - let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty effect_info output in - (inputs, output) + let output = compute_output_ty_from_decomposed dsg in + let inputs = dsg.fwd_inputs in + (inputs, output) in { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } @@ -1933,16 +1907,14 @@ and translate_panic (ctx : bs_ctx) : texpression = *) match ctx.bid with | None -> - if !Config.return_back_funs then - let back_tys = compute_back_tys ctx.sg None in - let back_tys = List.filter_map (fun x -> x) back_tys in - let tys = - if ctx.sg.fwd_info.ignore_output then back_tys - else ctx.sg.fwd_output :: back_tys - in - let output = mk_simpl_tuple_ty tys in - mk_output output - else mk_output ctx.sg.fwd_output + let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in + mk_output output | Some bid -> let output = mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs @@ -2080,107 +2052,103 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in - (* If we do not split the forward/backward functions: generate the - variables for the backward functions returned by the forward + (* Generate the variables for the backward functions returned by the forward function. *) let ctx, ignore_fwd_output, back_funs_map, back_funs = - if !Config.return_back_funs then ( - (* We need to compute the signatures of the backward functions. *) - let sg = Option.get call.sg in - let decls_ctx = ctx.decls_ctx in - let dsg = - translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx - fid call.regions_hierarchy sg - (List.map (fun _ -> None) sg.inputs) - in - log#ldebug - (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); - let tr_self, all_generics = - match call.trait_method_generics with - | None -> (UnknownTrait __FUNCTION__, generics) - | Some (all_generics, tr_self) -> - let all_generics = - ctx_translate_fwd_generic_args ctx all_generics - in - let tr_self = - translate_fwd_trait_instance_id ctx.type_ctx.type_infos - tr_self + (* We need to compute the signatures of the backward functions. *) + let sg = Option.get call.sg in + let decls_ctx = ctx.decls_ctx in + let dsg = + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx fid + call.regions_hierarchy sg + (List.map (fun _ -> None) sg.inputs) + in + log#ldebug + (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); + let tr_self, all_generics = + match call.trait_method_generics with + | None -> (UnknownTrait __FUNCTION__, generics) + | Some (all_generics, tr_self) -> + let all_generics = + ctx_translate_fwd_generic_args ctx all_generics + in + let tr_self = + translate_fwd_trait_instance_id ctx.type_ctx.type_infos + tr_self + in + (tr_self, all_generics) + in + let back_tys = + compute_back_tys_with_info dsg (Some (all_generics, tr_self)) + in + (* Introduce variables for the backward functions *) + (* Compute a proper basename for the variables *) + let back_fun_name = + let name = + match fid with + | FunId (FAssumed fid) -> ( + match fid with + | BoxNew -> "box_new" + | BoxFree -> "box_free" + | ArrayRepeat -> "array_repeat" + | ArrayIndexShared -> "index_shared" + | ArrayIndexMut -> "index_mut" + | ArrayToSliceShared -> "to_slice_shared" + | ArrayToSliceMut -> "to_slice_mut" + | SliceIndexShared -> "index_shared" + | SliceIndexMut -> "index_mut") + | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( + let decl = + FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls in - (tr_self, all_generics) + match Collections.List.last decl.name with + | PeIdent (s, _) -> s + | PeImpl _ -> + (* We shouldn't get there *) + raise (Failure "Unexpected")) in - let back_tys = - compute_back_tys_with_info dsg (Some (all_generics, tr_self)) - in - (* Introduce variables for the backward functions *) - (* Compute a proper basename for the variables *) - let back_fun_name = - let name = - match fid with - | FunId (FAssumed fid) -> ( - match fid with - | BoxNew -> "box_new" - | BoxFree -> "box_free" - | ArrayRepeat -> "array_repeat" - | ArrayIndexShared -> "index_shared" - | ArrayIndexMut -> "index_mut" - | ArrayToSliceShared -> "to_slice_shared" - | ArrayToSliceMut -> "to_slice_mut" - | SliceIndexShared -> "index_shared" - | SliceIndexMut -> "index_mut") - | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( - let decl = - FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls - in - match Collections.List.last decl.name with - | PeIdent (s, _) -> s - | PeImpl _ -> - (* We shouldn't get there *) - raise (Failure "Unexpected")) - in - name ^ "_back" - in - let ctx, back_vars = - fresh_opt_vars - (List.map - (fun ty -> - match ty with - | None -> None - | Some (back_sg, ty) -> - (* We insert a name for the variable only if the function - can fail: if it can fail, it means the call returns a backward - function. Otherwise, we it directly returns the value given - back by the backward function, which means we shouldn't - give it a name like "back..." (it doesn't make sense) *) - let name = - if back_sg.effect_info.can_fail then - Some back_fun_name - else None - in - Some (name, ty)) - back_tys) - ctx - in - let back_funs = - List.filter_map - (fun v -> - match v with - | None -> None - | Some v -> Some (mk_typed_pattern_from_var v None)) - back_vars - in - let gids = - List.map - (fun (g : T.region_var_group) -> g.id) - call.regions_hierarchy - in - let back_vars = - List.map (Option.map mk_texpression_from_var) back_vars - in - let back_funs_map = - RegionGroupId.Map.of_list (List.combine gids back_vars) - in - (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)) - else (ctx, false, None, []) + name ^ "_back" + in + let ctx, back_vars = + fresh_opt_vars + (List.map + (fun ty -> + match ty with + | None -> None + | Some (back_sg, ty) -> + (* We insert a name for the variable only if the function + can fail: if it can fail, it means the call returns a backward + function. Otherwise, we it directly returns the value given + back by the backward function, which means we shouldn't + give it a name like "back..." (it doesn't make sense) *) + let name = + if back_sg.effect_info.can_fail then Some back_fun_name + else None + in + Some (name, ty)) + back_tys) + ctx + in + let back_funs = + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars + in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in + let back_vars = + List.map (Option.map mk_texpression_from_var) back_vars + in + let back_funs_map = + RegionGroupId.Map.of_list (List.combine gids back_vars) + in + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs) in (* Compute the pattern for the destination *) let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in @@ -2407,19 +2375,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) raise (Failure "Unreachable") in let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in - let generics = ctx_translate_fwd_generic_args ctx call.generics in - (* Retrieve the original call and the parent abstractions *) - let _forward, backwards = get_abs_ancestors ctx abs call_id in - (* Retrieve the values consumed when we called the forward function and - * ended the parent backward functions: those give us part of the input - * values (rem: for now, as we disallow nested lifetimes, there can't be - * parent backward functions). - * Note that the forward inputs **include the fuel and the input state** - * (if we use those). *) - let fwd_inputs = call_info.forward_inputs in - let back_ancestors_inputs = - List.concat (List.map (fun (_abs, args) -> args) backwards) - in (* Retrieve the values consumed upon ending the loans inside this * abstraction: those give us the remaining input values *) let back_inputs = abs_to_consumed ctx ectx abs in @@ -2434,11 +2389,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) ([ back_state ], ctx, Some nstate) else ([], ctx, None) in - (* Concatenate all the inpus *) - let inherited_inputs = - if !Config.return_back_funs then [] - else List.concat [ fwd_inputs; back_ancestors_inputs ] - in let back_inputs = List.append back_inputs back_state in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the @@ -2459,58 +2409,33 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (* Retrieve the function id, and register the function call in the context if necessary.Arith_status *) let ctx, func = - bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs - back_inputs generics output.ty ctx + bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx in (* Translate the next expression *) let next_e = translate_expression e ctx in (* Put everything together *) - let inputs = List.append inherited_inputs back_inputs in + let inputs = back_inputs in let args_mplaces = List.map (fun _ -> None) inputs in let args = List.map (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - (* **Optimization**: - ================= - We do a small optimization here if we split the forward/backward functions. - If the backward function doesn't have any output, we don't introduce any function - call. - See the comment in {!Config.filter_useless_monadic_calls}. - - TODO: use an option to disallow backward functions from updating the state. - TODO: a backward function which only gives back shared borrows shouldn't - update the state (state updates should only be used for mutable borrows, - with objects like Rc for instance). - *) - if - (not !Config.return_back_funs) - && !Config.filter_useless_monadic_calls - && outputs = [] && nstate = None - then ( - (* No outputs - we do a small sanity check: the backward function - should have exactly the same number of inputs as the forward: - this number can be different only if the forward function returned - a value containing mutable borrows, which can't be the case... *) - assert (List.length inputs = List.length fwd_inputs); - next_e) - else - (* The backward function might also have been filtered if we do not - split the forward/backward functions *) - match func with - | None -> next_e - | Some func -> - log#ldebug - (lazy - (let args = List.map (texpression_to_string ctx) args in - "func: " - ^ texpression_to_string ctx func - ^ "\nfunc type: " - ^ pure_ty_to_string ctx func.ty - ^ "\n\nargs:\n" ^ String.concat "\n" args)); - let call = mk_apps func args in - mk_let effect_info.can_fail output call next_e + (* The backward function might have been filtered it does nothing + (consumes unit and returns unit). *) + match func with + | None -> next_e + | Some func -> + log#ldebug + (lazy + (let args = List.map (texpression_to_string ctx) args in + "func: " + ^ texpression_to_string ctx func + ^ "\nfunc type: " + ^ pure_ty_to_string ctx func.ty + ^ "\n\nargs:\n" ^ String.concat "\n" args)); + let call = mk_apps func args in + mk_let effect_info.can_fail output call next_e and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2637,10 +2562,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inputs *) - let inputs = - if !Config.return_back_funs then List.concat [ back_inputs; back_state ] - else List.concat [ fwd_inputs; back_inputs; back_state ] - in + let inputs = List.concat [ back_inputs; back_state ] in (* Retrieve the values given back by this function *) let ctx, outputs = abs_to_given_back None abs ctx in (* Group the output values together: first the updated inputs *) @@ -2670,77 +2592,46 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) - or a call to the variable we introduced for the backward function, if we merge the forward/backward functions *) let func = - if !Config.return_back_funs then - RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) - else - let func_ty = mk_arrows input_tys ret_ty in - let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; generics } in - Some { e = Qualif func; ty = func_ty } + RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) in - (* **Optimization**: - ================= - We do a small optimization here in case we split the forward/backward - functions. - If the backward function doesn't have any output, we don't introduce - any function call. - See the comment in {!Config.filter_useless_monadic_calls}. - - TODO: use an option to disallow backward functions from updating the state. - TODO: a backward function which only gives back shared borrows shouldn't - update the state (state updates should only be used for mutable borrows, - with objects like Rc for instance). - *) - if - (not !Config.return_back_funs) - && !Config.filter_useless_monadic_calls - && outputs = [] && nstate = None - then ( - (* No outputs - we do a small sanity check: the backward function - should have exactly the same number of inputs as the forward: - this number can be different only if the forward function returned - a value containing mutable borrows, which can't be the case... *) - assert (List.length inputs = List.length fwd_inputs); - next_e) - else - (* In case we merge the fwd/back functions we filter the backward - functions elsewhere *) - match func with - | None -> next_e - | Some func -> - let call = mk_apps func args in - (* Add meta-information - this is slightly hacky: we look at the - values consumed by the abstraction (note that those come from - *before* we applied the fixed-point context) and use them to - guide the naming of the output vars. - - Also, we need to convert the backward outputs from patterns to - variables. - - Finally, in practice, this works well only for loop bodies: - we do this only in this case. - TODO: improve the heuristics, to give weight to the hints for - instance. - *) - let next_e = - if ctx.inside_loop then - let consumed_values = abs_to_consumed ctx ectx abs in - let var_values = List.combine outputs consumed_values in - let var_values = - List.filter_map - (fun (var, v) -> - match var.Pure.value with - | PatVar (var, _) -> Some (var, v) - | _ -> None) - var_values - in - let vars, values = List.split var_values in - mk_emeta_symbolic_assignments vars values next_e - else next_e - in + (* We may have filtered the backward function elsewhere if it doesn't + do anything (doesn't consume anything and doesn't return anything) *) + match func with + | None -> next_e + | Some func -> + let call = mk_apps func args in + (* Add meta-information - this is slightly hacky: we look at the + values consumed by the abstraction (note that those come from + *before* we applied the fixed-point context) and use them to + guide the naming of the output vars. + + Also, we need to convert the backward outputs from patterns to + variables. + + Finally, in practice, this works well only for loop bodies: + we do this only in this case. + TODO: improve the heuristics, to give weight to the hints for + instance. + *) + let next_e = + if ctx.inside_loop then + let consumed_values = abs_to_consumed ctx ectx abs in + let var_values = List.combine outputs consumed_values in + let var_values = + List.filter_map + (fun (var, v) -> + match var.Pure.value with + | PatVar (var, _) -> Some (var, v) + | _ -> None) + var_values + in + let vars, values = List.split var_values in + mk_emeta_symbolic_assignments vars values next_e + else next_e + in - (* Create the let-binding *) - mk_let effect_info.can_fail output call next_e) + (* Create the let-binding *) + mk_let effect_info.can_fail output call next_e) and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -3068,48 +2959,40 @@ and translate_forward_end (ectx : C.eval_ctx) *) let ctx = (* Introduce variables for the inputs and the state variable - and update the context. *) - if !Config.return_back_funs then - (* If the forward/backward functions are not split, we need - to introduce fresh variables for the additional inputs, - because they are locally introduced in a lambda *) - let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in - let ctx, backward_inputs_no_state = - fresh_vars back_sg.inputs_no_state ctx - in - let ctx, backward_inputs_with_state = - if back_sg.effect_info.stateful then - let ctx, var, _ = bs_ctx_fresh_state_var ctx in - (ctx, backward_inputs_no_state @ [ var ]) - else (ctx, backward_inputs_no_state) - in - { - ctx with - backward_inputs_no_state = - RegionGroupId.Map.add bid backward_inputs_no_state - ctx.backward_inputs_no_state; - backward_inputs_with_state = - RegionGroupId.Map.add bid backward_inputs_with_state - ctx.backward_inputs_with_state; - } - else - (* Update the state variable *) - let back_state_var = - RegionGroupId.Map.find bid ctx.back_state_vars - in - { ctx with state_var = back_state_var } + and update the context. + + We need to introduce fresh variables for the additional inputs, + because they are locally introduced in a lambda. + *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let ctx, backward_inputs_no_state = + fresh_vars back_sg.inputs_no_state ctx + in + let ctx, backward_inputs_with_state = + if back_sg.effect_info.stateful then + let ctx, var, _ = bs_ctx_fresh_state_var ctx in + (ctx, backward_inputs_no_state @ [ var ]) + else (ctx, backward_inputs_no_state) + in + { + ctx with + backward_inputs_no_state = + RegionGroupId.Map.add bid backward_inputs_no_state + ctx.backward_inputs_no_state; + backward_inputs_with_state = + RegionGroupId.Map.add bid backward_inputs_with_state + ctx.backward_inputs_with_state; + } in let e = T.RegionGroupId.Map.find bid back_e in let finish e = (* Wrap in lambdas if necessary *) - if !Config.return_back_funs then - let inputs = - RegionGroupId.Map.find bid ctx.backward_inputs_with_state - in - let places = List.map (fun _ -> None) inputs in - mk_lambdas_from_vars inputs places e - else e + let inputs = + RegionGroupId.Map.find bid ctx.backward_inputs_with_state + in + let places = List.map (fun _ -> None) inputs in + mk_lambdas_from_vars inputs places e in (ctx, e, finish) in @@ -3131,85 +3014,83 @@ and translate_forward_end (ectx : C.eval_ctx) function, if needs be, and lookup the proper expression. *) let translate_end ctx = - if !Config.return_back_funs then - (* Compute the output of the forward function *) - let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in - let fwd_e = translate_one_end ctx None in - - (* Introduce the backward functions. *) - let back_el = - List.map - (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> - translate_one_end ctx (Some gid)) - (RegionGroupId.Map.bindings ctx.sg.back_sg) - in + (* Compute the output of the forward function *) + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in + let fwd_e = translate_one_end ctx None in - (* Compute whether the backward expressions should be evaluated straight - away or not (i.e., if we should bind them with monadic let-bindings - or not). We evaluate them straight away if they can fail and have no - inputs. *) - let evaluate_backs = - List.map - (fun (sg : back_sg_info) -> - if !Config.simplify_merged_fwd_backs then - sg.inputs = [] && sg.effect_info.can_fail - else false) - (RegionGroupId.Map.values ctx.sg.back_sg) - in + (* Introduce the backward functions. *) + let back_el = + List.map + (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> + translate_one_end ctx (Some gid)) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in - (* Introduce variables for the backward functions. - We lookup the LLBC definition in an attempt to derive pretty names - for those functions. *) - let _, back_vars = fresh_back_vars_for_current_fun ctx in + (* Compute whether the backward expressions should be evaluated straight + away or not (i.e., if we should bind them with monadic let-bindings + or not). We evaluate them straight away if they can fail and have no + inputs. *) + let evaluate_backs = + List.map + (fun (sg : back_sg_info) -> + if !Config.simplify_merged_fwd_backs then + sg.inputs = [] && sg.effect_info.can_fail + else false) + (RegionGroupId.Map.values ctx.sg.back_sg) + in - (* Create the return expressions *) - let vars = - let back_vars = List.filter_map (fun x -> x) back_vars in - if ctx.sg.fwd_info.ignore_output then back_vars - else pure_fwd_var :: back_vars - in - let vars = List.map mk_texpression_from_var vars in - let ret = mk_simpl_tuple_texpression vars in - - (* Introduce a fresh input state variable for the forward expression *) - let _ctx, state_var, state_pat = - if fwd_effect_info.stateful then - let ctx, var, pat = bs_ctx_fresh_state_var ctx in - (ctx, [ var ], [ pat ]) - else (ctx, [], []) - in + (* Introduce variables for the backward functions. + We lookup the LLBC definition in an attempt to derive pretty names + for those functions. *) + let _, back_vars = fresh_back_vars_for_current_fun ctx in + + (* Create the return expressions *) + let vars = + let back_vars = List.filter_map (fun x -> x) back_vars in + if ctx.sg.fwd_info.ignore_output then back_vars + else pure_fwd_var :: back_vars + in + let vars = List.map mk_texpression_from_var vars in + let ret = mk_simpl_tuple_texpression vars in + + (* Introduce a fresh input state variable for the forward expression *) + let _ctx, state_var, state_pat = + if fwd_effect_info.stateful then + let ctx, var, pat = bs_ctx_fresh_state_var ctx in + (ctx, [ var ], [ pat ]) + else (ctx, [], []) + in - let state_var = List.map mk_texpression_from_var state_var in - let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in - let ret = mk_result_return_texpression ret in + let state_var = List.map mk_texpression_from_var state_var in + let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in + let ret = mk_result_return_texpression ret in - (* Introduce all the let-bindings *) + (* Introduce all the let-bindings *) - (* Combine: - - the backward variables - - whether we should evaluate the expression for the backward function - (i.e., should we use a monadic let-binding or not - we do if the - backward functions don't have inputs and can fail) - - the expressions for the backward functions - *) - let back_vars_els = - List.filter_map - (fun (v, (eval, el)) -> - match v with None -> None | Some v -> Some (v, eval, el)) - (List.combine back_vars (List.combine evaluate_backs back_el)) - in - let e = - List.fold_right - (fun (var, evaluate, back_e) e -> - mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) - back_vars_els ret - in - (* Bind the expression for the forward output *) - let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in - let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in - mk_let fwd_effect_info.can_fail pat fwd_e e - else translate_one_end ctx ctx.bid + (* Combine: + - the backward variables + - whether we should evaluate the expression for the backward function + (i.e., should we use a monadic let-binding or not - we do if the + backward functions don't have inputs and can fail) + - the expressions for the backward functions + *) + let back_vars_els = + List.filter_map + (fun (v, (eval, el)) -> + match v with None -> None | Some v -> Some (v, eval, el)) + (List.combine back_vars (List.combine evaluate_backs back_el)) + in + let e = + List.fold_right + (fun (var, evaluate, back_e) e -> + mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) + back_vars_els ret + in + (* Bind the expression for the forward output *) + let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in + let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in + mk_let fwd_effect_info.can_fail pat fwd_e e in (* If we are (re-)entering a loop, we need to introduce a call to the @@ -3279,24 +3160,22 @@ and translate_forward_end (ectx : C.eval_ctx) backward functions of the outer function. *) let ctx, back_funs_map, back_funs = - if !Config.return_back_funs then - let ctx, back_vars = fresh_back_vars_for_current_fun ctx in - let back_funs = - List.filter_map - (fun v -> - match v with - | None -> None - | Some v -> Some (mk_typed_pattern_from_var v None)) - back_vars - in - let gids = RegionGroupId.Map.keys ctx.sg.back_sg in - let back_funs_map = - RegionGroupId.Map.of_list - (List.combine gids - (List.map (Option.map mk_texpression_from_var) back_vars)) - in - (ctx, Some back_funs_map, back_funs) - else (ctx, None, []) + let ctx, back_vars = fresh_back_vars_for_current_fun ctx in + let back_funs = + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars + in + let gids = RegionGroupId.Map.keys ctx.sg.back_sg in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids + (List.map (Option.map mk_texpression_from_var) back_vars)) + in + (ctx, Some back_funs_map, back_funs) in (* Introduce patterns *) @@ -3438,91 +3317,58 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* The output type of the loop function *) let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in let back_effect_infos, output_ty = - if !Config.return_back_funs then - (* The loop backward functions consume the same additional inputs as the parent - function, but have custom outputs *) - let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in - let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in - let back_info_tys = - List.map - (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> - (* Remark: the effect info of the backward function for the loop - is almost the same as for the backward function of the parent function. - Quite importantly, the fact that the function is stateful and/or can fail - mostly depends on whether it has inputs or not, and the backward functions - for the loops have the same inputs as the backward functions for the parent - function. - *) - let effect_info = back_sg.effect_info in - let effect_info = { effect_info with is_rec = true } in - (* Compute the input/output types *) - let inputs = List.map snd back_sg.inputs in - let outputs = given_back in - (* Filter if necessary *) - let ty = - if - !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] - then None - else - let output = mk_simpl_tuple_ty outputs in - let output = - mk_back_output_ty_from_effect_info effect_info inputs output - in - let ty = mk_arrows inputs output in - Some ty - in - ((id, effect_info), ty)) - (List.combine back_sgs given_back_tys) - in - let back_info = List.map fst back_info_tys in - let back_info = RegionGroupId.Map.of_list back_info in - let back_tys = List.filter_map snd back_info_tys in - let output = - if ctx.sg.fwd_info.ignore_output then back_tys - else ctx.sg.fwd_output :: back_tys - in - let output = mk_simpl_tuple_ty output in - let effect_info = ctx.sg.fwd_info.effect_info in - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if effect_info.can_fail && inputs <> [] then mk_result_ty output - else output - in - (back_info, output) - else - let back_info = - RegionGroupId.Map.of_list - (List.map - (fun ((id, back_sg) : _ * back_sg_info) -> - (id, { back_sg.effect_info with is_rec = true })) - (RegionGroupId.Map.bindings ctx.sg.back_sg)) - in - let output = - match ctx.bid with - | None -> - (* Forward function: same type as the parent function *) - (translate_fun_sig_from_decomposed ctx.sg None).output - | Some rg_id -> - (* Backward function: custom return type *) - let doutputs = - T.RegionGroupId.Map.find rg_id rg_to_given_back_tys - in - let output = mk_simpl_tuple_ty doutputs in - let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let output = - if fwd_effect_info.stateful then - mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if fwd_effect_info.can_fail then mk_result_ty output else output - in - output - in - (back_info, output) + (* The loop backward functions consume the same additional inputs as the parent + function, but have custom outputs *) + let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in + let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in + let back_info_tys = + List.map + (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> + (* Remark: the effect info of the backward function for the loop + is almost the same as for the backward function of the parent function. + Quite importantly, the fact that the function is stateful and/or can fail + mostly depends on whether it has inputs or not, and the backward functions + for the loops have the same inputs as the backward functions for the parent + function. + *) + let effect_info = back_sg.effect_info in + let effect_info = { effect_info with is_rec = true } in + (* Compute the input/output types *) + let inputs = List.map snd back_sg.inputs in + let outputs = given_back in + (* Filter if necessary *) + let ty = + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty + in + ((id, effect_info), ty)) + (List.combine back_sgs given_back_tys) + in + let back_info = List.map fst back_info_tys in + let back_info = RegionGroupId.Map.of_list back_info in + let back_tys = List.filter_map snd back_info_tys in + let output = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty output in + let effect_info = ctx.sg.fwd_info.effect_info in + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + in + (back_info, output) in (* Add the loop information in the context *) @@ -3708,21 +3554,19 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression) let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Translate *) let def = ctx.fun_decl in - let bid = ctx.bid in + assert (ctx.bid = None); log#ldebug (lazy ("SymbolicToPure.translate_fun_decl: " ^ name_to_string ctx def.name - ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string bid - ^ ")\n")); + ^ "\n")); (* Translate the declaration *) let def_id = def.def_id in let llbc_name = def.name in let name = name_to_string ctx llbc_name in (* Translate the signature *) - let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in + let signature = translate_fun_sig_from_decomposed ctx.sg in let regions_hierarchy = FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies in @@ -3732,7 +3576,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx (FunId (FRegular def_id)) None bid + get_fun_effect_info ctx (FunId (FRegular def_id)) None None in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) @@ -3760,37 +3604,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = if effect_info.stateful_group then [ mk_state_var ctx.state_var ] else [] in - (* Compute the list of (properly ordered) backward input variables *) - let backward_inputs : var list = - match bid with - | None -> [] - | Some back_id -> - assert (not !Config.return_back_funs); - let parents_ids = - list_ordered_ancestor_region_groups regions_hierarchy back_id - in - let backward_ids = List.append parents_ids [ back_id ] in - List.concat - (List.map - (fun id -> - T.RegionGroupId.Map.find id ctx.backward_inputs_no_state) - backward_ids) - in - (* Introduce the backward input state (the state at call site of the - * *backward* function), if necessary *) - let back_state = - if effect_info.stateful && Option.is_some bid then - let state_var = - RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars - in - [ mk_state_var state_var ] - else [] - in (* Group the inputs together *) - let inputs = - List.concat - [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ] - in + let inputs = List.concat [ fuel; ctx.forward_inputs; fwd_state ] in let inputs_lvs = List.map (fun v -> mk_typed_pattern_from_var v None) inputs in @@ -3799,16 +3614,10 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (lazy ("SymbolicToPure.translate_fun_decl: " ^ name_to_string ctx def.name - ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string bid - ^ ")" ^ "\n- forward_inputs: " + ^ "\n- inputs: " ^ String.concat ", " (List.map show_var ctx.forward_inputs) - ^ "\n- fwd_state: " + ^ "\n- state: " ^ String.concat ", " (List.map show_var fwd_state) - ^ "\n- backward_inputs: " - ^ String.concat ", " (List.map show_var backward_inputs) - ^ "\n- back_state: " - ^ String.concat ", " (List.map show_var back_state) ^ "\n- signature.inputs: " ^ String.concat ", " (List.map (pure_ty_to_string ctx) signature.inputs))); @@ -3837,7 +3646,6 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = kind = def.kind; num_loops; loop_id; - back_id = bid; llbc_name; name; signature; -- cgit v1.2.3 From fe2a2cb34148e46e32cdcfbf100e38d9986082cd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 16:06:35 +0100 Subject: Make progress on propagating the changes --- compiler/SymbolicToPure.ml | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 859d6f17..2db5f66c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2035,7 +2035,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | S.Fun (fid, call_id) -> (* Regular function call *) let fid_t = translate_fun_id_or_trait_method_ref ctx fid in - let func = Fun (FromLlbc (fid_t, None, None)) in + let func = Fun (FromLlbc (fid_t, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = get_fun_effect_info ctx fid None None in @@ -2539,8 +2539,6 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in - let generics = loop_info.generics in - let fwd_inputs = Option.get loop_info.forward_inputs in (* Retrieve the additional backward inputs. Note that those are actually the backward inputs of the function we are synthesizing (and that we need to *transmit* to the loop backward function): they are not the @@ -2582,10 +2580,6 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output.ty else output.ty - in (* Create the expression for the function: - it is either a call to a top-level function, if we split the forward/backward functions @@ -3218,7 +3212,7 @@ and translate_forward_end (ectx : C.eval_ctx) let out_pat = mk_simpl_tuple_pattern out_pats in let loop_call = - let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in + let fun_id = Fun (FromLlbc (FunId fid, Some loop_id)) in let func = { id = FunOrOp fun_id; generics = loop_info.generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = @@ -3567,9 +3561,6 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let name = name_to_string ctx llbc_name in (* Translate the signature *) let signature = translate_fun_sig_from_decomposed ctx.sg in - let regions_hierarchy = - FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies - in (* Translate the body, if there is *) let body = match body with -- cgit v1.2.3