diff options
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r-- | compiler/SymbolicToPure.ml | 158 |
1 files changed, 55 insertions, 103 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 9d249cfb..74bc20ae 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -12,45 +12,6 @@ module FA = FunsAnalysis (** The local logger *) let log = L.symbolic_to_pure_log -(* TODO: carrying configurations everywhere is super annoying. - Group everything in references in a [Config.ml] file (put aside the execution - mode, maybe). -*) -type config = { - filter_useless_back_calls : bool; - (** If [true], filter the useless calls to backward functions. - - The useless calls are calls to backward functions which have no outputs. - This case happens if the original Rust function only takes *shared* borrows - as inputs, and is thus pretty common. - - We are allowed to do this only because in this specific case, - the backward function fails *exactly* when the forward function fails - (they actually do exactly the same thing, the only difference being - that the forward function can potentially return a value), and upon - reaching the place where we should introduce a call to the backward - function, we know we have introduced a call to the forward function. - - Also note that in general, backward functions "do more things" than - forward functions, and have more opportunities to fail (even though - in the generated code, calls to the backward functions should fail - exactly when the corresponding, previous call to the forward functions - failed). - - We might want to move this optimization to the micro-passes subsequent - to the translation from symbolic to pure, but it is really super easy - to do it when going from symbolic to pure. - Note that we later filter the useless *forward* calls in the micro-passes, - where it is more natural to do. - *) - backward_no_state_update : bool; - (** Controls whether backward functions update the state, in case we use - a state ({!use_state}). - - See {!Translate.config.backward_no_state_update}. - *) -} - type type_context = { llbc_type_decls : T.type_decl TypeDeclId.Map.t; type_decls : type_decl TypeDeclId.Map.t; @@ -528,15 +489,14 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : [backward_no_state_update]: see {!config} *) -let get_fun_effect_info (backward_no_state_update : bool) - (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) - (gid : T.RegionGroupId.id option) : fun_effect_info = +let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) + (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with | A.Regular fid -> let info = A.FunDeclId.Map.find fid fun_infos in let stateful_group = info.stateful in let stateful = - stateful_group && ((not backward_no_state_update) || gid = None) + stateful_group && ((not !Config.backward_no_state_update) || gid = None) in { can_fail = info.can_fail; stateful_group; stateful } | A.Assumed aid -> @@ -553,9 +513,8 @@ let get_fun_effect_info (backward_no_state_update : bool) name (outputs for backward functions come from borrows in the inputs of the forward function) which we use as hints to generate pretty names. *) -let translate_fun_sig (backward_no_state_update : bool) - (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) - (types_infos : TA.type_infos) (sg : A.fun_sig) +let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) + (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = (* Retrieve the list of parent backward functions *) @@ -606,20 +565,18 @@ let translate_fun_sig (backward_no_state_update : bool) List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] in (* Is the function stateful, and can it fail? *) - let effect_info = - get_fun_effect_info backward_no_state_update fun_infos fun_id bid - in + let effect_info = get_fun_effect_info fun_infos fun_id bid in (* If the function is stateful, the inputs are: - forward: [fwd_ty0, ..., fwd_tyn, state] - backward: - - if config.no_backward_state: [fwd_ty0, ..., fwd_tyn, state, back_ty, state] + - if {!Config.backward_no_state_update}: [fwd_ty0, ..., fwd_tyn, state, back_ty, state] - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty] The backward takes the same state as input as the forward function, together with the state at the point where it gets called, if it is stateful. - See the comments for {!Translate.config.backward_no_state_update} + See the comments for {!Config.backward_no_state_update} *) let fwd_state_ty = (* For the forward state, we check if the *whole group* is stateful. @@ -1134,17 +1091,16 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : let abs_ancestors = list_ancestor_abstractions ctx abs in (call_info.forward, abs_ancestors) -let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) - : texpression = +let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = match e with | S.Return opt_v -> translate_return opt_v ctx | Panic -> translate_panic ctx - | FunCall (call, e) -> translate_function_call config call e ctx - | EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx - | EvalGlobal (gid, sv, e) -> translate_global_eval config gid sv e ctx - | Assertion (v, e) -> translate_assertion config v e ctx - | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx - | Meta (meta, e) -> translate_meta config meta e ctx + | FunCall (call, e) -> translate_function_call call e ctx + | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx + | EvalGlobal (gid, sv, e) -> translate_global_eval gid sv e ctx + | Assertion (v, e) -> translate_assertion v e ctx + | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx + | Meta (meta, e) -> translate_meta meta e ctx | ForwardEnd e -> (* Update the current state with the additional state received by the backward function, if needs be *) @@ -1153,7 +1109,7 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) | None -> ctx | Some _ -> { ctx with state_var = ctx.back_state_var } in - translate_expression config e ctx + translate_expression e ctx and translate_panic (ctx : bs_ctx) : texpression = (* Here we use the function return type - note that it is ok because @@ -1213,8 +1169,8 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression (* TODO: we should use a [Return] function *) mk_result_return_texpression output -and translate_function_call (config : config) (call : S.call) (e : S.expression) - (ctx : bs_ctx) : texpression = +and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : + texpression = (* Translate the function call *) let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in let args = @@ -1236,8 +1192,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = - get_fun_effect_info config.backward_no_state_update - ctx.fun_context.fun_infos fid None + get_fun_effect_info ctx.fun_context.fun_infos fid None in (* If the function is stateful: * - add the state input argument @@ -1306,12 +1261,12 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* Translate the next expression *) - let next_e = translate_expression config e ctx in + let next_e = translate_expression e ctx in (* Put together *) mk_let effect_info.can_fail dest_v call next_e -and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) - (ctx : bs_ctx) : texpression = +and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : + texpression = log#ldebug (lazy ("translate_end_abstraction: abstraction kind: " @@ -1359,7 +1314,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty)) variables_values; (* Translate the next expression *) - let next_e = translate_expression config e ctx in + let next_e = translate_expression e ctx in (* Generate the assignemnts *) let monadic = false in List.fold_right @@ -1377,8 +1332,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) raise (Failure "Unreachable") in let effect_info = - get_fun_effect_info config.backward_no_state_update - ctx.fun_context.fun_infos fun_id (Some abs.back_id) + get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some abs.back_id) in let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in (* Retrieve the original call and the parent abstractions *) @@ -1459,7 +1413,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) * if necessary *) let ctx, func = bs_ctx_register_backward_call abs back_inputs ctx in (* Translate the next expression *) - let next_e = translate_expression config e ctx in + let next_e = translate_expression e ctx in (* Put everything together *) let args_mplaces = List.map (fun _ -> None) inputs in let args = @@ -1479,14 +1433,15 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) * ================= * We do a small optimization here: if the backward function doesn't * have any output, we don't introduce any function call. - * See the comment in [config]. + * See the comment in {!Config.filter_useless_monadic_calls}. * * TODO: use an option to disallow backward functions from updating the state. * TODO: a backward function which only gives back shared borrows shouldn't * update the state (state updates should only be used for mutable borrows, * with objects like Rc for instance. *) - if config.filter_useless_back_calls && outputs = [] && nstate = None then ( + if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None + then ( (* No outputs - we do a small sanity check: the backward function * should have exactly the same number of inputs as the forward: * this number can be different only if the forward function returned @@ -1549,7 +1504,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) assert (given_back.ty = input.ty)) given_back_inputs; (* Translate the next expression *) - let next_e = translate_expression config e ctx in + let next_e = translate_expression e ctx in (* Generate the assignments *) let monadic = false in List.fold_right @@ -1557,20 +1512,20 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) mk_let monadic given_back (mk_texpression_from_var input_var) e) given_back_inputs next_e -and translate_global_eval (config : config) (gid : A.GlobalDeclId.id) - (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = +and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) + (e : S.expression) (ctx : bs_ctx) : texpression = let ctx, var = fresh_var_for_symbolic_value sval ctx in let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in let global_expr = { id = Global gid; type_args = [] } in (* We use translate_fwd_ty to translate the global type *) let ty = ctx_translate_fwd_ty ctx decl.ty in let gval = { e = Qualif global_expr; ty } in - let e = translate_expression config e ctx in + let e = translate_expression e ctx in mk_let false (mk_typed_pattern_from_var var None) gval e -and translate_assertion (config : config) (v : V.typed_value) (e : S.expression) - (ctx : bs_ctx) : texpression = - let next_e = translate_expression config e ctx in +and translate_assertion (v : V.typed_value) (e : S.expression) (ctx : bs_ctx) : + texpression = + let next_e = translate_expression e ctx in let monadic = true in let v = typed_value_to_texpression ctx v in let args = [ v ] in @@ -1580,8 +1535,8 @@ and translate_assertion (config : config) (v : V.typed_value) (e : S.expression) let assertion = mk_apps func args in mk_let monadic (mk_dummy_pattern mk_unit_ty) assertion next_e -and translate_expansion (config : config) (p : S.mplace option) - (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression = +and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) + (exp : S.expansion) (ctx : bs_ctx) : texpression = (* Translate the scrutinee *) let scrutinee_var = lookup_var_for_symbolic_value sv ctx in let scrutinee = mk_texpression_from_var scrutinee_var in @@ -1598,7 +1553,7 @@ and translate_expansion (config : config) (p : S.mplace option) (* The (mut/shared) borrow type is extracted to identity: we thus simply * introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in - let next_e = translate_expression config e ctx in + let next_e = translate_expression e ctx in let monadic = false in mk_let monadic (mk_typed_pattern_from_var var None) @@ -1615,7 +1570,7 @@ and translate_expansion (config : config) (p : S.mplace option) (* There is exactly one branch: no branching *) let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in let ctx, vars = fresh_vars_for_symbolic_values svl ctx in - let branch = translate_expression config branch ctx in + let branch = translate_expression branch ctx in match type_id with | T.AdtId adt_id -> (* Detect if this is an enumeration or not *) @@ -1706,7 +1661,7 @@ and translate_expansion (config : config) (p : S.mplace option) in let pat_ty = scrutinee.ty in let pat = mk_adt_pattern pat_ty variant_id vars in - let branch = translate_expression config branch ctx in + let branch = translate_expression branch ctx in { pat; branch } in let branches = @@ -1727,8 +1682,8 @@ and translate_expansion (config : config) (p : S.mplace option) | ExpandBool (true_e, false_e) -> (* We don't need to update the context: we don't introduce any * new values/variables *) - let true_e = translate_expression config true_e ctx in - let false_e = translate_expression config false_e ctx in + let true_e = translate_expression true_e ctx in + let false_e = translate_expression false_e ctx in let e = Switch ( mk_opt_mplace_texpression scrutinee_mplace scrutinee, @@ -1742,12 +1697,12 @@ and translate_expansion (config : config) (p : S.mplace option) match_branch = (* We don't need to update the context: we don't introduce any * new values/variables *) - let branch = translate_expression config branch_e ctx in + let branch = translate_expression branch_e ctx in let pat = mk_typed_pattern_from_primitive_value (PV.Scalar v) in { pat; branch } in let branches = List.map translate_branch branches in - let otherwise = translate_expression config otherwise ctx in + let otherwise = translate_expression otherwise ctx in let pat_ty = Integer int_ty in let otherwise_pat : typed_pattern = { value = PatDummy; ty = pat_ty } in let otherwise : match_branch = @@ -1764,9 +1719,9 @@ and translate_expansion (config : config) (p : S.mplace option) List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches); { e; ty } -and translate_meta (config : config) (meta : S.meta) (e : S.expression) - (ctx : bs_ctx) : texpression = - let next_e = translate_expression config e ctx in +and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : + texpression = + let next_e = translate_expression e ctx in let meta = match meta with | S.Assignment (lp, rv, rp) -> @@ -1779,8 +1734,7 @@ and translate_meta (config : config) (meta : S.meta) (e : S.expression) let ty = next_e.ty in { e; ty } -let translate_fun_decl (config : config) (ctx : bs_ctx) - (body : S.expression option) : fun_decl = +let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Translate *) let def = ctx.fun_decl in let bid = ctx.bid in @@ -1802,10 +1756,9 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) match body with | None -> None | Some body -> - let body = translate_expression config body ctx in + let body = translate_expression body ctx in let effect_info = - get_fun_effect_info config.backward_no_state_update - ctx.fun_context.fun_infos (Regular def_id) bid + get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid in (* Sanity check *) type_check_texpression ctx body; @@ -1902,8 +1855,8 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = - optional names for the outputs values (we derive them for the backward functions) *) -let translate_fun_signatures (backward_no_state_update : bool) - (fun_infos : FA.fun_info A.FunDeclId.Map.t) (types_infos : TA.type_infos) +let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) + (types_infos : TA.type_infos) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdMap.t = (* For every function, translate the signatures of: @@ -1914,8 +1867,7 @@ let translate_fun_signatures (backward_no_state_update : bool) (sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list = (* The forward function *) let fwd_sg = - translate_fun_sig backward_no_state_update fun_infos fun_id types_infos sg - input_names None + translate_fun_sig fun_infos fun_id types_infos sg input_names None in let fwd_id = (fun_id, None) in (* The backward functions *) @@ -1923,8 +1875,8 @@ let translate_fun_signatures (backward_no_state_update : bool) List.map (fun (rg : T.region_var_group) -> let tsg = - translate_fun_sig backward_no_state_update fun_infos fun_id - types_infos sg input_names (Some rg.id) + translate_fun_sig fun_infos fun_id types_infos sg input_names + (Some rg.id) in let id = (fun_id, Some rg.id) in (id, tsg)) |