diff options
Diffstat (limited to '')
| -rw-r--r-- | src/PureMicroPasses.ml | 74 | ||||
| -rw-r--r-- | src/SymbolicToPure.ml | 103 | ||||
| -rw-r--r-- | src/Translate.ml | 42 | ||||
| -rw-r--r-- | src/main.ml | 5 | 
4 files changed, 165 insertions, 59 deletions
| diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 59871600..7094d885 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -47,6 +47,20 @@ type config = {            See the comments for [expression_contains_child_call_in_all_paths]            for additional explanations. +           +          TODO: rename to [filter_useless_monadic_calls] +       *) +  filter_useless_functions : bool; +      (** If [filter_unused_monadic_calls] is activated, some functions +          become useless: if this option is true, we don't extract them. + +          The calls to functions which always get filtered are: +          - the forward functions with unit return value +          - the backward functions which don't output anything (backward +            functions coming from rust functions with no mutable borrows +            as input values - note that if a function doesn't take mutable +            borrows as inputs, it can't return mutable borrows; we actually +            dynamically check for that).         *)    add_unit_args : bool;        (** Add unit input arguments to functions with no arguments. *) @@ -612,11 +626,47 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx)    { def with body; inputs_lvs }  (** Return `None` if the function is a backward function with no outputs (so -    that we eliminate the definition which is useless) *) -let filter_if_backward_with_no_outputs (def : fun_def) : fun_def option = -  if Option.is_some def.back_id && def.signature.outputs = [] then None +    that we eliminate the definition which is useless). + +    Note that the calls to such functions are filtered when translating from +    symbolic to pure. Here, we remove the definitions altogether, because they +    are now useless +  *) +let filter_if_backward_with_no_outputs (config : config) (def : fun_def) : +    fun_def option = +  if +    config.filter_useless_functions && Option.is_some def.back_id +    && def.signature.outputs = [] +  then None    else Some def +(** Return `false` if the forward function is useless and should be filtered. + +    - a forward function with no output (comes from a Rust function with +      unit return type) +    - the function has mutable borrows as inputs (which is materialized +      by the fact we generated backward functions which were not filtered). + +    In such situation, every call to the Rust function will be translated to: +    - a call to the forward function which returns nothing +    - calls to the backward functions +    As a failing backward function implies the forward function also fails, +    we can filter the calls to the forward function, which thus becomes +    useless. +    In such situation, we can remove the forward function definition +    altogether. +  *) +let keep_forward (config : config) (trans : pure_fun_translation) : bool = +  let fwd, backs = trans in +  (* Note that at this point, the output types are no longer seen as tuples: +   * they should be lists of length 1. *) +  if +    config.filter_useless_functions +    && fwd.signature.outputs = [ mk_result_ty unit_ty ] +    && backs <> [] +  then false +  else true +  (** Add unit arguments (optionally) to functions with no arguments, and      change their output type to use `result`    *) @@ -852,7 +902,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :     * Note that the calls to those functions should already have been removed,     * when translating from symbolic to pure. Here, we remove the definitions     * altogether, because they are now useless *) -  let def = filter_if_backward_with_no_outputs def in +  let def = filter_if_backward_with_no_outputs config def in    match def with    | None -> None @@ -924,9 +974,21 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :        (* We are done *)        Some def +(** Return the forward/backward translations on which we applied the micro-passes. + +    Also returns a boolean indicating whether the forward function should be kept +    or not (because useful/useless - `true` means we need to keep the forward +    function). +    Note that we don't "filter" the forward function and return a boolean instead, +    because this function contains useful information to extract the backward +    functions: keeping it is not necessary but more convenient. + *)  let apply_passes_to_pure_fun_translation (config : config) (ctx : trans_ctx) -    (trans : pure_fun_translation) : pure_fun_translation = +    (trans : pure_fun_translation) : bool * pure_fun_translation = +  (* Apply the passes to the individual functions *)    let forward, backwards = trans in    let forward = Option.get (apply_passes_to_def config ctx forward) in    let backwards = List.filter_map (apply_passes_to_def config ctx) backwards in -  (forward, backwards) +  let trans = (forward, backwards) in +  (* Compute whether we need to filter the forward function or not *) +  (keep_forward config trans, trans) diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index ca214d7c..f2ed1053 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -12,6 +12,34 @@ module PP = PrintPure  (** The local logger *)  let log = L.symbolic_to_pure_log +type config = { +  filter_useless_back_calls : bool; +      (** If `true`, filter the useless calls to backward functions. +        +          The useless calls are calls to backward functions which have no outputs. +          This case happens if the original Rust function only takes *shared* borrows +          as inputs, and is thus pretty common. + +          We are allowed to do this only because in this specific case, +          the backward function fails *exactly* when the forward function fails +          (they actually do exactly the same thing, the only difference being +          that the forward function can potentially return a value), and upon +          reaching the place where we should introduce a call to the backward +          function, we know we have introduced a call to the forward function. +           +          Also note that in general, backward functions "do more things" than +          forward functions, and have more opportunities to fail (even though +          in the generated code, backward functions should fail exactly when +          the forward functions fail). +           +          We might want to move this optimization to the micro-passes subsequent +          to the translation from symbolic to pure, but it is really super easy +          to do it when going from symbolic to pure. +          Note that we later filter the useless *forward* calls in the micro-passes, +          where it is more natural to do. +       *) +} +  type type_context = {    cfim_type_defs : T.type_def TypeDefId.Map.t;    type_defs : type_def TypeDefId.Map.t; @@ -915,9 +943,10 @@ let fun_is_monadic (fun_id : A.fun_id) : bool =    | A.Local _ -> true    | A.Assumed aid -> Assumed.assumed_is_monadic aid -let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = +let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) +    : texpression =    match e with -  | S.Return opt_v -> translate_return opt_v ctx +  | S.Return opt_v -> translate_return config opt_v ctx    | Panic ->        (* Here we use the function return type - note that it is ok because         * we don't match on panics which happen inside the function body - @@ -926,13 +955,13 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =        let e = Value (v, None) in        let ty = v.ty in        { e; ty } -  | FunCall (call, e) -> translate_function_call call e ctx -  | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx -  | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx -  | Meta (meta, e) -> translate_meta meta e ctx +  | FunCall (call, e) -> translate_function_call config call e ctx +  | EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx +  | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx +  | Meta (meta, e) -> translate_meta config meta e ctx -and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression -    = +and translate_return (_config : config) (opt_v : V.typed_value option) +    (ctx : bs_ctx) : texpression =    (* There are two cases:       - either we are translating a forward function, in which case the optional         value should be `Some` (it is the returned value) @@ -964,8 +993,8 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression        let ty = ret_value.ty in        { e; ty } -and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : -    texpression = +and translate_function_call (config : config) (call : S.call) (e : S.expression) +    (ctx : bs_ctx) : texpression =    (* Translate the function call *)    let type_params = List.map (ctx_translate_fwd_ty ctx) call.type_params in    let args = List.map (typed_value_to_rvalue ctx) call.args in @@ -1011,12 +1040,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :    let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in    let call = { e = call; ty = call_ty } in    (* Translate the next expression *) -  let next_e = translate_expression e ctx in +  let next_e = translate_expression config e ctx in    (* Put together *)    mk_let monadic dest_v call next_e -and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : -    texpression = +and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) +    (ctx : bs_ctx) : texpression =    log#ldebug      (lazy        ("translate_end_abstraction: abstraction kind: " @@ -1064,7 +1093,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :          (fun (var, v) -> assert ((var : var).ty = (v : typed_rvalue).ty))          variables_values;        (* Translate the next expression *) -      let next_e = translate_expression e ctx in +      let next_e = translate_expression config e ctx in        (* Generate the assignemnts *)        let monadic = false in        List.fold_right @@ -1129,7 +1158,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :         * if necessary *)        let ctx, func = bs_ctx_register_backward_call abs ctx in        (* Translate the next expression *) -      let next_e = translate_expression e ctx in +      let next_e = translate_expression config e ctx in        (* Put everything together *)        let args_mplaces = List.map (fun _ -> None) inputs in        let args = @@ -1144,17 +1173,10 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :        (* **Optimization**:         * =================         * We do a small optimization here: if the backward function doesn't -       * have any output, we don't introduce any function call. This case -       * happens if the function only takes *shared* borrows as inputs, -       * and is thus pretty common. We might want to move the optimization -       * to the micro-passes code, but it is really super easy to do it -       * here. Note that we are allowed to do it only because in this case, -       * the backward function *fails exactly when the forward function fails* -       * (they actually do exactly the same thing, the only difference being -       * that the forward function can potentially return a value), and we -       * know that we called the forward function before. +       * have any output, we don't introduce any function call. +       * See the comment in [config].         *) -      if outputs = [] then ( +      if config.filter_useless_back_calls && outputs = [] then (          (* No outputs - we do a small sanity check: the backward function           * should have exactly the same number of inputs as the forward:           * this number can be different only if the forward function returned @@ -1218,7 +1240,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :            assert (given_back.ty = input.ty))          given_back_inputs;        (* Translate the next expression *) -      let next_e = translate_expression e ctx in +      let next_e = translate_expression config e ctx in        (* Generate the assignments *)        let monadic = false in        List.fold_right @@ -1228,8 +1250,8 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :              e)          given_back_inputs next_e -and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) -    (exp : S.expansion) (ctx : bs_ctx) : texpression = +and translate_expansion (config : config) (p : S.mplace option) +    (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression =    (* Translate the scrutinee *)    let scrutinee_var = lookup_var_for_symbolic_value sv ctx in    let scrutinee = mk_typed_rvalue_from_var scrutinee_var in @@ -1246,7 +1268,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)            (* The (mut/shared) borrow type is extracted to identity: we thus simply             * introduce an reassignment *)            let ctx, var = fresh_var_for_symbolic_value nsv ctx in -          let next_e = translate_expression e ctx in +          let next_e = translate_expression config e ctx in            let monadic = false in            mk_let monadic              (mk_typed_lvalue_from_var var None) @@ -1263,7 +1285,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)            (* There is exactly one branch: no branching *)            let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in            let ctx, vars = fresh_vars_for_symbolic_values svl ctx in -          let branch = translate_expression branch ctx in +          let branch = translate_expression config branch ctx in            match type_id with            | T.AdtId adt_id ->                (* Detect if this is an enumeration or not *) @@ -1349,7 +1371,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)              in              let pat_ty = scrutinee.ty in              let pat = mk_adt_lvalue pat_ty variant_id vars in -            let branch = translate_expression branch ctx in +            let branch = translate_expression config branch ctx in              { pat; branch }            in            let branches = @@ -1367,8 +1389,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)    | ExpandBool (true_e, false_e) ->        (* We don't need to update the context: we don't introduce any         * new values/variables *) -      let true_e = translate_expression true_e ctx in -      let false_e = translate_expression false_e ctx in +      let true_e = translate_expression config true_e ctx in +      let false_e = translate_expression config false_e ctx in        let e =          Switch            (mk_value_expression scrutinee scrutinee_mplace, If (true_e, false_e)) @@ -1381,12 +1403,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)            match_branch =          (* We don't need to update the context: we don't introduce any           * new values/variables *) -        let branch = translate_expression branch_e ctx in +        let branch = translate_expression config branch_e ctx in          let pat = mk_typed_lvalue_from_constant_value (V.Scalar v) in          { pat; branch }        in        let branches = List.map translate_branch branches in -      let otherwise = translate_expression otherwise ctx in +      let otherwise = translate_expression config otherwise ctx in        let pat_ty = Integer int_ty in        let otherwise_pat : typed_lvalue = { value = LvVar Dummy; ty = pat_ty } in        let otherwise : match_branch = @@ -1402,9 +1424,9 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)          List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches);        { e; ty } -and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : -    texpression = -  let next_e = translate_expression e ctx in +and translate_meta (config : config) (meta : S.meta) (e : S.expression) +    (ctx : bs_ctx) : texpression = +  let next_e = translate_expression config e ctx in    let meta =      match meta with      | S.Assignment (p, rv) -> @@ -1416,7 +1438,8 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :    let ty = next_e.ty in    { e; ty } -let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def = +let translate_fun_def (config : config) (ctx : bs_ctx) (body : S.expression) : +    fun_def =    let def = ctx.fun_def in    let bid = ctx.bid in    log#ldebug @@ -1431,7 +1454,7 @@ let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def =    let def_id = def.A.def_id in    let basename = def.name in    let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in -  let body = translate_expression body ctx in +  let body = translate_expression config body ctx in    (* Compute the list of (properly ordered) input variables *)    let backward_inputs : var list =      match bid with diff --git a/src/Translate.ml b/src/Translate.ml index 3781fc33..d51ec826 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -59,7 +59,7 @@ let translate_function_to_symbolics (config : C.partial_config)      TODO: maybe we should introduce a record for this.  *)  let translate_function_to_pure (config : C.partial_config) -    (trans_ctx : trans_ctx) +    (mp_config : Micro.config) (trans_ctx : trans_ctx)      (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t)      (pure_type_defs : Pure.type_def Pure.TypeDefId.Map.t) (fdef : A.fun_def) :      pure_fun_translation = @@ -134,9 +134,17 @@ let translate_function_to_pure (config : C.partial_config)      { ctx with forward_inputs }    in +  (* The symbolic to pure config *) +  let sp_config = +    { +      SymbolicToPure.filter_useless_back_calls = +        mp_config.filter_unused_monadic_calls; +    } +  in +    (* Translate the forward function *)    let pure_forward = -    SymbolicToPure.translate_fun_def +    SymbolicToPure.translate_fun_def sp_config        (add_forward_inputs (fst symbolic_forward) ctx)        (snd symbolic_forward)    in @@ -196,7 +204,7 @@ let translate_function_to_pure (config : C.partial_config)      in      (* Translate *) -    SymbolicToPure.translate_fun_def ctx symbolic +    SymbolicToPure.translate_fun_def sp_config ctx symbolic    in    let pure_backwards =      List.map translate_backward fdef.signature.regions_hierarchy @@ -207,7 +215,7 @@ let translate_function_to_pure (config : C.partial_config)  let translate_module_to_pure (config : C.partial_config)      (mp_config : Micro.config) (m : M.cfim_module) : -    trans_ctx * Pure.type_def list * pure_fun_translation list = +    trans_ctx * Pure.type_def list * (bool * pure_fun_translation) list =    (* Debug *)    log#ldebug (lazy "translate_module_to_pure"); @@ -249,7 +257,8 @@ let translate_module_to_pure (config : C.partial_config)    (* Translate all the functions *)    let pure_translations =      List.map -      (translate_function_to_pure config trans_ctx fun_sigs type_defs_map) +      (translate_function_to_pure config mp_config trans_ctx fun_sigs +         type_defs_map)        m.functions    in @@ -305,7 +314,7 @@ let translate_module (filename : string) (dest_dir : string)    let extract_ctx =      List.fold_left -      (fun extract_ctx def -> +      (fun extract_ctx (_, def) ->          ExtractToFStar.extract_fun_def_register_names extract_ctx def)        extract_ctx trans_funs    in @@ -337,7 +346,8 @@ let translate_module (filename : string) (dest_dir : string)    let trans_funs =      Pure.FunDefId.Map.of_list        (List.map -         (fun ((fd, bdl) : pure_fun_translation) -> (fd.def_id, (fd, bdl))) +         (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> +           (fd.def_id, (keep_fwd, (fd, bdl))))           trans_funs)    in @@ -368,11 +378,16 @@ let translate_module (filename : string) (dest_dir : string)    (* In case of (non-mutually) recursive functions, we use a simple procedure to     * check if the forward and backward functions are mutually recursive.     *) -  let export_functions (is_rec : bool) (pure_ls : pure_fun_translation list) : -      unit = -    (* Generate the function definitions *) +  let export_functions (is_rec : bool) +      (pure_ls : (bool * pure_fun_translation) list) : unit = +    (* Generate the function definitions, filtering the uselss forward +     * functions. *)      let fls = -      List.concat (List.map (fun (fwd, back_ls) -> fwd :: back_ls) pure_ls) +      List.concat +        (List.map +           (fun (keep_fwd, (fwd, back_ls)) -> +             if keep_fwd then fwd :: back_ls else back_ls) +           pure_ls)      in      (* Check if the functions are mutually recursive - this really works       * to check if the forward and backward translations of a single @@ -397,8 +412,9 @@ let translate_module (filename : string) (dest_dir : string)      (* Insert unit tests if necessary *)      if test_unit_functions then        List.iter -        (fun (fwd, _) -> -          ExtractToFStar.extract_unit_test_if_unit_fun extract_ctx fmt fwd) +        (fun (keep_fwd, (fwd, _)) -> +          if keep_fwd then +            ExtractToFStar.extract_unit_test_if_unit_fun extract_ctx fmt fwd)          pure_ls    in diff --git a/src/main.ml b/src/main.ml index 5e652809..17ab6421 100644 --- a/src/main.ml +++ b/src/main.ml @@ -27,6 +27,7 @@ let () =    let decompose_monads = ref false in    let unfold_monads = ref true in    let filter_unused_calls = ref true in +  let filter_useless_functions = ref true in    let test_units = ref false in    let test_trans_units = ref false in @@ -50,6 +51,9 @@ let () =        ( "-filter-unused-calls",          Arg.Set filter_unused_calls,          " Filter the unused function calls, when possible" ); +      ( "-filter-useless-funs", +        Arg.Set filter_useless_functions, +        " Filter the useless forward/backward functions" );        ( "-test-units",          Arg.Set test_units,          " Test the unit functions with the concrete interpreter" ); @@ -142,6 +146,7 @@ let () =            Micro.decompose_monadic_let_bindings = !decompose_monads;            unfold_monadic_let_bindings = !unfold_monads;            filter_unused_monadic_calls = !filter_unused_calls; +          filter_useless_functions = !filter_useless_functions;            add_unit_args = false;          }        in | 
