diff options
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r-- | compiler/SymbolicToPure.ml | 383 |
1 files changed, 241 insertions, 142 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 6d01614d..9d249cfb 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -12,6 +12,10 @@ module FA = FunsAnalysis (** The local logger *) let log = L.symbolic_to_pure_log +(* TODO: carrying configurations everywhere is super annoying. + Group everything in references in a [Config.ml] file (put aside the execution + mode, maybe). +*) type config = { filter_useless_back_calls : bool; (** If [true], filter the useless calls to backward functions. @@ -39,6 +43,12 @@ type config = { Note that we later filter the useless *forward* calls in the micro-passes, where it is more natural to do. *) + backward_no_state_update : bool; + (** Controls whether backward functions update the state, in case we use + a state ({!use_state}). + + See {!Translate.config.backward_no_state_update}. + *) } type type_context = { @@ -110,7 +120,23 @@ type bs_ctx = { *) var_counter : VarId.generator; state_var : VarId.id; - (** The current state variable, in case we use a state *) + (** The current state variable, in case the function is stateful *) + back_state_var : VarId.id; + (** The additional input state variable received by a stateful backward function. + When generating stateful functions, we generate code of the following + form: + + {[ + (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd + ... // the state may be updated + (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back + ]} + + When translating a backward function, we need at some point to update + [state_var] with [back_state_var], to account for the fact that the + state may have been updated by the caller between the call to the + forward function and the call to the backward function. + *) forward_inputs : var list; (** The input parameters for the forward function *) backward_inputs : var list T.RegionGroupId.Map.t; @@ -498,20 +524,26 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : let abs_ids = list_ancestor_abstractions_ids ctx abs in List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids -(** Small utility. *) -let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info = +(** Small utility. + + [backward_no_state_update]: see {!config} + *) +let get_fun_effect_info (backward_no_state_update : bool) + (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) + (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with | A.Regular fid -> let info = A.FunDeclId.Map.find fid fun_infos in - let input_state = info.stateful in - let output_state = input_state && gid = None in - { can_fail = info.can_fail; input_state; output_state } + let stateful_group = info.stateful in + let stateful = + stateful_group && ((not backward_no_state_update) || gid = None) + in + { can_fail = info.can_fail; stateful_group; stateful } | A.Assumed aid -> { can_fail = Assumed.assumed_can_fail aid; - input_state = false; - output_state = false; + stateful_group = false; + stateful = false; } (** Translate a function signature. @@ -519,10 +551,11 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) Note that the function also takes a list of names for the inputs, and computes, for every output for the backward functions, a corresponding name (outputs for backward functions come from borrows in the inputs - of the forward function). + of the forward function) which we use as hints to generate pretty names. *) -let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig) +let translate_fun_sig (backward_no_state_update : bool) + (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) + (types_infos : TA.type_infos) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = (* Retrieve the list of parent backward functions *) @@ -572,17 +605,42 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) *) List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] in - (* Does the function take a state as input, does it return a state and can - * it fail? *) - let effect_info = get_fun_effect_info fun_infos fun_id bid in - (* *) - let state_ty = if effect_info.input_state then [ mk_state_ty ] else [] in + (* Is the function stateful, and can it fail? *) + let effect_info = + get_fun_effect_info backward_no_state_update fun_infos fun_id bid + in + (* If the function is stateful, the inputs are: + - forward: [fwd_ty0, ..., fwd_tyn, state] + - backward: + - if config.no_backward_state: [fwd_ty0, ..., fwd_tyn, state, back_ty, state] + - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty] + + The backward takes the same state as input as the forward function, + together with the state at the point where it gets called, if it is + stateful. + + See the comments for {!Translate.config.backward_no_state_update} + *) + let fwd_state_ty = + (* For the forward state, we check if the *whole group* is stateful. + See {!effect_info}. *) + if effect_info.stateful_group then [ mk_state_ty ] else [] + in + let back_state_ty = + (* For the backward state, we check if the function is a backward function, + and it is stateful *) + if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else [] + in + (* Concatenate the inputs, in the following order: * - forward inputs - * - state input + * - forward state input * - backward inputs + * - backward state input *) - let inputs = List.concat [ fwd_inputs; state_ty; back_inputs ] in + let inputs = + List.concat [ fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] + in (* Outputs *) let output_names, doutputs = match gid with @@ -620,7 +678,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) let output = mk_simpl_tuple_ty doutputs in (* Add the output state *) let output = - if effect_info.output_state then mk_simpl_tuple_ty [ mk_state_ty; output ] + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] else output in (* Wrap in a result type *) @@ -1087,6 +1145,15 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) | Assertion (v, e) -> translate_assertion config v e ctx | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx | Meta (meta, e) -> translate_meta config meta e ctx + | ForwardEnd e -> + (* Update the current state with the additional state received by the backward + function, if needs be *) + let ctx = + match ctx.bid with + | None -> ctx + | Some _ -> { ctx with state_var = ctx.back_state_var } + in + translate_expression config e ctx and translate_panic (ctx : bs_ctx) : texpression = (* Here we use the function return type - note that it is ok because @@ -1095,13 +1162,15 @@ and translate_panic (ctx : bs_ctx) : texpression = (* If we use a state monad, we need to add a lambda for the state variable *) (* Note that only forward functions return a state *) let output_ty = mk_simpl_tuple_ty ctx.sg.doutputs in - if ctx.sg.info.effect_info.output_state then + (* TODO: we should use a [Fail] function *) + if ctx.sg.info.effect_info.stateful then (* Create the [Fail] value *) let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in let ret_v = mk_result_fail_texpression ret_ty in ret_v else mk_result_fail_texpression output_ty +(** [opt_v]: the value to return, in case we translate a forward function *) and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -1109,44 +1178,40 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression value should be [Some] (it is the returned value) - or we are translating a backward function, in which case it should be [None] *) - match ctx.bid with - | None -> - (* Forward function *) - let v = Option.get opt_v in - let v = typed_value_to_texpression ctx v in - (* We may need to return a state - * - error-monad: Return x - * - state-error: Return (state, x) - * *) - if ctx.sg.info.effect_info.output_state then - let state_var = - { - id = ctx.state_var; - basename = Some ConstStrings.state_basename; - ty = mk_state_ty; - } + (* Compute the values that we should return *without the state and the result + * wrapper* *) + let output = + match ctx.bid with + | None -> + (* Forward function *) + let v = Option.get opt_v in + typed_value_to_texpression ctx v + | Some bid -> + (* Backward function *) + (* Sanity check *) + assert (opt_v = None); + (* Group the variables in which we stored the values we need to give back. + * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = + T.RegionGroupId.Map.find bid ctx.backward_outputs in - let state_rvalue = mk_texpression_from_var state_var in - mk_result_return_texpression - (mk_simpl_tuple_texpression [ state_rvalue; v ]) - else mk_result_return_texpression v - | Some bid -> - (* Backward function *) - (* Sanity check *) - assert (opt_v = None); - assert (not ctx.sg.info.effect_info.output_state); - (* We simply need to return the variables in which we stored the values - * we need to give back. - * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = - T.RegionGroupId.Map.find bid ctx.backward_outputs - in - let field_values = List.map mk_texpression_from_var backward_outputs in - (* Backward functions never return a state *) - (* TODO: we should use a [fail] function, it would be cleaner *) - let ret_value = mk_simpl_tuple_texpression field_values in - let ret_value = mk_result_return_texpression ret_value in - ret_value + let field_values = List.map mk_texpression_from_var backward_outputs in + mk_simpl_tuple_texpression field_values + in + (* We may need to return a state + * - error-monad: Return x + * - state-error: Return (state, x) + * *) + let effect_info = ctx.sg.info.effect_info in + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + (* TODO: we should use a [Return] function *) + mk_result_return_texpression output and translate_function_call (config : config) (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -1171,29 +1236,26 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fid None - in - (* Add the state input argument *) - let args = - if effect_info.input_state then - let state_var = { e = Var ctx.state_var; ty = mk_state_ty } in - List.append args [ state_var ] - else args + get_fun_effect_info config.backward_no_state_update + ctx.fun_context.fun_infos fid None in - (* Generate a fresh state variable if the function call introduces - * a new variable *) - let ctx, out_state = - if effect_info.input_state then - let ctx, var = bs_ctx_fresh_state_var ctx in - (ctx, Some var) - else (ctx, None) + (* If the function is stateful: + * - add the state input argument + * - generate a fresh state variable for the returned state + *) + let args, ctx, out_state = + if effect_info.stateful then + let state_var = mk_state_texpression ctx.state_var in + let ctx, nstate_var = bs_ctx_fresh_state_var ctx in + (List.append args [ state_var ], ctx, Some nstate_var) + else (args, ctx, None) in (* Register the function call *) let ctx = bs_ctx_register_forward_call call_id call args ctx in (ctx, func, effect_info, args, out_state) | S.Unop E.Not -> let effect_info = - { can_fail = false; input_state = false; output_state = false } + { can_fail = false; stateful_group = false; stateful = false } in (ctx, Unop Not, effect_info, args, None) | S.Unop E.Neg -> ( @@ -1203,14 +1265,14 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) (* Note that negation can lead to an overflow and thus fail (it * is thus monadic) *) let effect_info = - { can_fail = true; input_state = false; output_state = false } + { can_fail = true; stateful_group = false; stateful = false } in (ctx, Unop (Neg int_ty), effect_info, args, None) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast (src_ty, tgt_ty)) -> (* Note that cast can fail *) let effect_info = - { can_fail = true; input_state = false; output_state = false } + { can_fail = true; stateful_group = false; stateful = false } in (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) | S.Binop binop -> ( @@ -1222,8 +1284,8 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let effect_info = { can_fail = ExpressionsUtils.binop_can_fail binop; - input_state = false; - output_state = false; + stateful_group = false; + stateful = false; } in (ctx, Binop (binop, int_ty0), effect_info, args, None) @@ -1307,6 +1369,17 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) | V.FunCall -> let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in let call = call_info.forward in + let fun_id = + match call.call_id with + | S.Fun (fun_id, _) -> fun_id + | Unop _ | Binop _ -> + (* Those don't have backward functions *) + raise (Failure "Unreachable") + in + let effect_info = + get_fun_effect_info config.backward_no_state_update + ctx.fun_context.fun_infos fun_id (Some abs.back_id) + in let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs in @@ -1322,8 +1395,21 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (* Retrieve the values consumed upon ending the loans inside this * abstraction: those give us the remaining input values *) let back_inputs = abs_to_consumed ctx abs in + (* If the function is stateful: + * - add the state input argument + * - generate a fresh state variable for the returned state + *) + let back_state, ctx, nstate = + if effect_info.stateful then + let back_state = mk_state_texpression ctx.state_var in + let ctx, nstate = bs_ctx_fresh_state_var ctx in + ([ back_state ], ctx, Some nstate) + else ([], ctx, None) + in + (* Concatenate all the inpus *) let inputs = - List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs ] + List.concat + [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ] in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the @@ -1333,43 +1419,42 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) List.append (List.map translate_opt_mplace call.args_places) [ None ] in let ctx, outputs = abs_to_given_back output_mpl abs ctx in - (* Group the output values together (note that for now, backward functions - * never return an output state) *) + (* Group the output values together: first the updated inputs *) let output = mk_simpl_tuple_pattern outputs in - (* Sanity check: the inputs and outputs have the proper number and the proper type *) - let fun_id = - match call.call_id with - | S.Fun (fun_id, _) -> fun_id - | Unop _ | Binop _ -> - (* Those don't have backward functions *) - raise (Failure "Unreachable") + (* Add the returned state if the function is stateful *) + let output = + match nstate with + | None -> output + | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in - - let inst_sg = - get_instantiated_fun_sig fun_id (Some abs.back_id) type_args ctx + (* Sanity check: the inputs and outputs have the proper number and the proper type *) + let _ = + let inst_sg = + get_instantiated_fun_sig fun_id (Some abs.back_id) type_args ctx + in + log#ldebug + (lazy + ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" + ^ string_of_int (List.length inputs) + ^ "): " + ^ String.concat ", " (List.map show_texpression inputs) + ^ "\n- inst_sg.inputs (" + ^ string_of_int (List.length inst_sg.inputs) + ^ "): " + ^ String.concat ", " (List.map show_ty inst_sg.inputs))); + List.iter + (fun (x, ty) -> assert ((x : texpression).ty = ty)) + (List.combine inputs inst_sg.inputs); + log#ldebug + (lazy + ("\n- outputs: " + ^ string_of_int (List.length outputs) + ^ "\n- expected outputs: " + ^ string_of_int (List.length inst_sg.doutputs))); + List.iter + (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) + (List.combine outputs inst_sg.doutputs) in - log#ldebug - (lazy - ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" - ^ string_of_int (List.length inputs) - ^ "): " - ^ String.concat ", " (List.map show_texpression inputs) - ^ "\n- inst_sg.inputs (" - ^ string_of_int (List.length inst_sg.inputs) - ^ "): " - ^ String.concat ", " (List.map show_ty inst_sg.inputs))); - List.iter - (fun (x, ty) -> assert ((x : texpression).ty = ty)) - (List.combine inputs inst_sg.inputs); - log#ldebug - (lazy - ("\n- outputs: " - ^ string_of_int (List.length outputs) - ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.doutputs))); - List.iter - (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.doutputs); (* Retrieve the function id, and register the function call in the context * if necessary *) let ctx, func = bs_ctx_register_backward_call abs back_inputs ctx in @@ -1382,9 +1467,6 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some abs.back_id) - in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty output.ty else output.ty @@ -1398,8 +1480,13 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) * We do a small optimization here: if the backward function doesn't * have any output, we don't introduce any function call. * See the comment in [config]. + * + * TODO: use an option to disallow backward functions from updating the state. + * TODO: a backward function which only gives back shared borrows shouldn't + * update the state (state updates should only be used for mutable borrows, + * with objects like Rc for instance. *) - if config.filter_useless_back_calls && outputs = [] then ( + if config.filter_useless_back_calls && outputs = [] && nstate = None then ( (* No outputs - we do a small sanity check: the backward function * should have exactly the same number of inputs as the forward: * this number can be different only if the forward function returned @@ -1708,32 +1795,28 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (* Translate the declaration *) let def_id = def.A.def_id in let basename = def.name in - (* Lookup the signature *) - let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in + (* Retrieve the signature *) + let signature = ctx.sg in (* Translate the body, if there is *) let body = match body with | None -> None | Some body -> let body = translate_expression config body ctx in - (* Sanity check *) - type_check_texpression ctx body; - (* Introduce the input state, if necessary *) let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid + get_fun_effect_info config.backward_no_state_update + ctx.fun_context.fun_infos (Regular def_id) bid in - let input_state = - if effect_info.input_state then - [ - { - id = ctx.state_var; - basename = Some ConstStrings.state_basename; - ty = mk_state_ty; - }; - ] + (* Sanity check *) + type_check_texpression ctx body; + (* Introduce the forward input state (the state at call site of the + * *forward* function), if necessary. *) + let fwd_state = + (* We check if the *whole group* is stateful. See {!effect_info} *) + if effect_info.stateful_group then [ mk_state_var ctx.state_var ] else [] in - (* Compute the list of (properly ordered) input variables *) + (* Compute the list of (properly ordered) backward input variables *) let backward_inputs : var list = match bid with | None -> [] @@ -1747,8 +1830,17 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs) backward_ids) in + (* Introduce the backward input state (the state at call site of the + * *backward* function), if necessary *) + let back_state = + if effect_info.stateful && Option.is_some bid then + [ mk_state_var ctx.back_state_var ] + else [] + in + (* Group the inputs together *) let inputs = - List.concat [ ctx.forward_inputs; input_state; backward_inputs ] + List.concat + [ ctx.forward_inputs; fwd_state; backward_inputs; back_state ] in let inputs_lvs = List.map (fun v -> mk_typed_pattern_from_var v None) inputs @@ -1756,12 +1848,18 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (* Sanity check *) log#ldebug (lazy - ("SymbolicToPure.translate_fun_decl:" ^ "\n- forward_inputs: " + ("SymbolicToPure.translate_fun_decl: " + ^ Print.fun_name_to_string def.A.name + ^ " (" + ^ Print.option_to_string T.RegionGroupId.to_string bid + ^ ")" ^ "\n- forward_inputs: " ^ String.concat ", " (List.map show_var ctx.forward_inputs) - ^ "\n- input_state: " - ^ String.concat ", " (List.map show_var input_state) + ^ "\n- fwd_state: " + ^ String.concat ", " (List.map show_var fwd_state) ^ "\n- backward_inputs: " ^ String.concat ", " (List.map show_var backward_inputs) + ^ "\n- back_state: " + ^ String.concat ", " (List.map show_var back_state) ^ "\n- signature.inputs: " ^ String.concat ", " (List.map show_ty signature.inputs))); assert ( @@ -1804,8 +1902,8 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = - optional names for the outputs values (we derive them for the backward functions) *) -let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (types_infos : TA.type_infos) +let translate_fun_signatures (backward_no_state_update : bool) + (fun_infos : FA.fun_info A.FunDeclId.Map.t) (types_infos : TA.type_infos) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdMap.t = (* For every function, translate the signatures of: @@ -1816,7 +1914,8 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) (sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list = (* The forward function *) let fwd_sg = - translate_fun_sig fun_infos fun_id types_infos sg input_names None + translate_fun_sig backward_no_state_update fun_infos fun_id types_infos sg + input_names None in let fwd_id = (fun_id, None) in (* The backward functions *) @@ -1824,8 +1923,8 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) List.map (fun (rg : T.region_var_group) -> let tsg = - translate_fun_sig fun_infos fun_id types_infos sg input_names - (Some rg.id) + translate_fun_sig backward_no_state_update fun_infos fun_id + types_infos sg input_names (Some rg.id) in let id = (fun_id, Some rg.id) in (id, tsg)) |