diff options
author | Son Ho | 2023-12-15 18:54:06 +0100 |
---|---|---|
committer | Son Ho | 2023-12-15 18:54:06 +0100 |
commit | 884edaa3ee975626f184249d491f343fc02a66e2 (patch) | |
tree | 38aaf96746d6f61de2ef461a2a0add56db561083 /compiler | |
parent | 955fdab55304979ba2d61432ea654241f20abaa4 (diff) |
Make progress on updating the code
Diffstat (limited to '')
-rw-r--r-- | compiler/PureMicroPasses.ml | 48 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 79 | ||||
-rw-r--r-- | compiler/Translate.ml | 207 |
3 files changed, 134 insertions, 200 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 0102b13e..a7c2f154 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -776,9 +776,11 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool) in { def with body = Some body } -(** Given a forward or backward function call, is there, for every execution +(** For the cases where we split the forward/backward functions. + + Given a forward or backward function call, is there, for every execution path, a child backward function called later with exactly the same input - list prefix? We use this to filter useless function calls: if there are + list prefix. We use this to filter useless function calls: if there are such child calls, we can remove this one (in case its outputs are not used). We do this check because we can't simply remove function calls whose @@ -1008,17 +1010,21 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) * under some conditions. *) match (filter_monadic_calls, opt_destruct_function_call re) with | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) -> - (* We need to check if there is a child call - see - * the comments for: - * [expression_contains_child_call_in_all_paths] *) - let has_child_call = - expression_contains_child_call_in_all_paths ctx fid lp_id - rg_id tys args e - in - if has_child_call then (* Filter *) - (e.e, fun _ -> used) - else (* No child call: don't filter *) - dont_filter () + (* If we split the forward/backward functions. + + We need to check if there is a child call - see + the comments for: + [expression_contains_child_call_in_all_paths] *) + if not !Config.return_back_funs then + let has_child_call = + expression_contains_child_call_in_all_paths ctx fid + lp_id rg_id tys args e + in + if has_child_call then (* Filter *) + (e.e, fun _ -> used) + else (* No child call: don't filter *) + dont_filter () + else dont_filter () | _ -> (* Not an LLBC function call or not allowed to filter: we can't filter *) dont_filter () @@ -1509,9 +1515,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = altogether. *) let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = - (* Note that at this point, the output types are no longer seen as tuples: - * they should be lists of length 1. *) - if + (* The question of filtering the forward functions arises only if we split + the forward/backward functions *) + if !Config.return_back_funs then true + else if + (* Note that at this point, the output types are no longer seen as tuples: + * they should be lists of length 1. *) !Config.filter_useless_functions && fwd.f.signature.output = mk_result_ty mk_unit_ty && backs <> [] @@ -1957,9 +1966,10 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Remove the backward functions with no outputs. - * Note that the calls to those functions should already have been removed, - * when translating from symbolic to pure. Here, we remove the definitions - * altogether, because they are now useless *) + + Note that the *calls* to those functions should already have been removed, + when translating from symbolic to pure. Here, we remove the definitions + altogether, because they are now useless *) let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in let opt_def = filter_if_backward_with_no_outputs def in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 08f9e950..204fc399 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -45,7 +45,6 @@ type fun_sig_named_outputs = { type fun_context = { llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; - fun_sigs : fun_sig_named_outputs RegularFunIdNotLoopMap.t; (** *) fun_infos : fun_info A.FunDeclId.Map.t; regions_hierarchies : T.region_var_groups FunIdMap.t; } @@ -144,7 +143,11 @@ type bs_ctx = { a symbolic expansion or upon ending an abstraction, for instance) we introduce a new variable (with a let-binding). *) - var_counter : VarId.generator; + var_counter : VarId.generator ref; + (** Using a ref to make sure all the variables identifiers are unique. + TODO: this is not very clean, and the code was initially written without + a reference (and it's shape hasn't changed). We should use DeBruijn indices. + *) state_var : VarId.id; (** The current state variable, in case the function is stateful *) back_state_vars : VarId.id RegionGroupId.Map.t; @@ -1131,13 +1134,14 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let state_var = { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } in let state_pat = mk_typed_pattern_from_var state_var None in (* Update the context *) - let ctx = { ctx with var_counter; state_var = id } in + ctx.var_counter := var_counter; + let ctx = { ctx with state_var = id } in (* Return *) (ctx, state_var, state_pat) @@ -1146,11 +1150,11 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = let fresh_var_llbc_ty (basename : string option) (ty : T.ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let ty = ctx_translate_fwd_ty ctx ty in let var = { id; basename; ty } in (* Update the context *) - let ctx = { ctx with var_counter } in + ctx.var_counter := var_counter; (* Return *) (ctx, var) @@ -1184,10 +1188,10 @@ let fresh_named_vars_for_symbolic_values let fresh_var (basename : string option) (ty : ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let var = { id; basename; ty } in (* Update the context *) - let ctx = { ctx with var_counter } in + ctx.var_counter := var_counter; (* Return *) (ctx, var) @@ -3303,65 +3307,6 @@ let translate_type_decls (ctx : Contexts.decls_ctx) : type_decl list = List.map (translate_type_decl ctx) (TypeDeclId.Map.values ctx.type_ctx.type_decls) -(** Translates function signatures. - - Takes as input a list of function information containing: - - the function id - - a list of optional names for the inputs - - the function signature - - Returns a map from forward/backward functions identifiers to: - - translated function signatures - - optional names for the outputs values (we derive them for the backward - functions) - *) -let translate_fun_signatures (decls_ctx : C.decls_ctx) - (functions : (A.fun_id * string option list * A.fun_sig) list) : - fun_sig_named_outputs RegularFunIdNotLoopMap.t = - (* For every function, translate the signatures of: - - the forward function - - the backward functions - *) - let translate_one (fun_id : A.fun_id) (input_names : string option list) - (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list - = - log#ldebug - (lazy - ("Translating signature of function: " - ^ Print.Expressions.fun_id_to_string - (Print.Contexts.decls_ctx_to_fmt_env decls_ctx) - fun_id)); - (* Retrieve the regions hierarchy *) - let regions_hierarchy = - FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies - in - (* The forward function *) - let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in - let fwd_id = (fun_id, None) in - (* The backward functions *) - let back_sgs = - if !Config.return_back_funs then [] - else - List.map - (fun (rg : T.region_var_group) -> - let tsg = - translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) - in - let id = (fun_id, Some rg.id) in - (id, tsg)) - regions_hierarchy - in - (* Return *) - (fwd_id, fwd_sg) :: back_sgs - in - let translated = - List.concat - (List.map (fun (id, names, sg) -> translate_one id names sg) functions) - in - List.fold_left - (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m) - RegularFunIdNotLoopMap.empty translated - let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl) : trait_decl = let { diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 06d4bd6d..8b221c93 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -6,7 +6,6 @@ open LlbcAst open Contexts module SA = SymbolicAst module Micro = PureMicroPasses -open PureUtils open TranslateCore (** The local logger *) @@ -43,7 +42,6 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) : TODO: maybe we should introduce a record for this. *) let translate_function_to_pure (trans_ctx : trans_ctx) - (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdNotLoopMap.t) (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) : pure_fun_translation_no_loops = (* Debug *) @@ -58,13 +56,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Convert the symbolic ASTs to pure ASTs: *) (* Initialize the context *) - let forward_sig = - RegularFunIdNotLoopMap.find (FRegular def_id, None) fun_sigs - in let sv_to_var = SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in let state_var, var_counter = Pure.VarId.fresh var_counter in - let back_state_var, var_counter = Pure.VarId.fresh var_counter in let fuel0, var_counter = Pure.VarId.fresh var_counter in let fuel, var_counter = Pure.VarId.fresh var_counter in let calls = FunCallId.Map.empty in @@ -89,7 +83,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx) let fun_context = { SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls; - fun_sigs; fun_infos = trans_ctx.fun_ctx.fun_infos; regions_hierarchies = trans_ctx.fun_ctx.regions_hierarchies; } @@ -126,17 +119,45 @@ let translate_function_to_pure (trans_ctx : trans_ctx) !m in + let input_names = + match fdef.body with + | None -> List.map (fun _ -> None) fdef.signature.inputs + | Some body -> + List.map + (fun (v : var) -> v.name) + (LlbcAstUtils.fun_body_get_input_vars body) + in + + let sg = + SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx (FRegular def_id) + fdef.signature input_names + in + + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_context.regions_hierarchies + in + + let var_counter, back_state_vars = + if !Config.return_back_funs then (var_counter, []) + else + List.fold_left_map + (fun var_counter (region_vars : region_var_group) -> + let gid = region_vars.id in + let var, var_counter = Pure.VarId.fresh var_counter in + (var_counter, (gid, var))) + var_counter regions_hierarchy + in + let back_state_vars = RegionGroupId.Map.of_list back_state_vars in + let ctx = { SymbolicToPure.bid = None; - (* Dummy for now *) - sg = forward_sig.sg; - fwd_sg = forward_sig.sg; + sg; (* Will need to be updated for the backward functions *) sv_to_var; - var_counter; + var_counter = ref var_counter; state_var; - back_state_var; + back_state_vars; fuel0; fuel; type_context; @@ -146,9 +167,11 @@ let translate_function_to_pure (trans_ctx : trans_ctx) trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls; fun_decl = fdef; forward_inputs = []; - (* Empty for now *) - backward_inputs = RegionGroupId.Map.empty; - (* Empty for now *) + (* Initialized just below *) + backward_inputs_no_state = RegionGroupId.Map.empty; + (* Initialized just below *) + backward_inputs_with_state = RegionGroupId.Map.empty; + (* Initialized just below *) backward_outputs = RegionGroupId.Map.empty; loop_backward_outputs = None; (* Empty for now *) @@ -180,6 +203,51 @@ let translate_function_to_pure (trans_ctx : trans_ctx) | _ -> raise (Failure "Unreachable") in + (* Add the backward inputs *) + let ctx, backward_inputs_no_state, backward_inputs_with_state = + if !Config.return_back_funs then (ctx, [], []) + else + let ctx, inputs_no_with_state = + List.fold_left_map + (fun ctx (region_vars : region_var_group) -> + let gid = region_vars.id in + let back_sg = RegionGroupId.Map.find gid sg.back_sg in + let ctx, no_state = + SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx + in + let ctx, with_state = + SymbolicToPure.fresh_vars back_sg.inputs ctx + in + (ctx, ((gid, no_state), (gid, with_state)))) + ctx regions_hierarchy + in + let inputs_no_state, inputs_with_state = + List.split inputs_no_with_state + in + (ctx, inputs_no_state, inputs_with_state) + in + let backward_inputs_no_state = + RegionGroupId.Map.of_list backward_inputs_no_state + in + let backward_inputs_with_state = + RegionGroupId.Map.of_list backward_inputs_with_state + in + let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in + + (* Add the backward outputs *) + let ctx, backward_outputs = + List.fold_left_map + (fun ctx (region_vars : region_var_group) -> + let gid = region_vars.id in + let back_sg = RegionGroupId.Map.find gid sg.back_sg in + let outputs = List.combine back_sg.output_names back_sg.outputs in + let ctx, vars = SymbolicToPure.fresh_vars outputs ctx in + (ctx, (gid, vars))) + ctx regions_hierarchy + in + let backward_outputs = RegionGroupId.Map.of_list backward_outputs in + let ctx = { ctx with backward_outputs } in + (* Translate the forward function *) let pure_forward = match symbolic_trans with @@ -187,7 +255,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) | Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast) in - (* Translate the backward functions *) + (* Translate the backward functions, if we split the forward and backward functions *) let translate_backward (rg : region_var_group) : Pure.fun_decl = (* For the backward inputs/outputs initialization: we use the fact that * there are no nested borrows for now, and so that the region groups @@ -197,83 +265,20 @@ let translate_function_to_pure (trans_ctx : trans_ctx) match symbolic_trans with | None -> - (* Initialize the context - note that the ret_ty is not really - * useful as we don't translate a body *) - let backward_sg = - RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs - in - let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in - + (* Initialize the context *) + let ctx = { ctx with bid = Some back_id } in (* Translate *) SymbolicToPure.translate_fun_decl ctx None | Some (_, symbolic) -> - (* Finish initializing the context by adding the additional input - variables required by the backward function. - *) - let backward_sg = - RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs - in - (* We need to ignore the forward inputs, and the state input (if there is) *) - let backward_inputs = - let sg = backward_sg.sg in - (* TODO: *) - assert (not !Config.return_back_funs); - (* We need to ignore the forward state and the backward state *) - let num_forward_inputs = - sg.info.fwd_info.num_inputs_with_fuel_with_state - in - let num_back_inputs = - match sg.info.back_info with - | SingleBack (Some info) -> info.num_inputs_no_fuel_no_state - | _ -> raise (Failure "Unexpected") - in - Collections.List.subslice sg.inputs num_forward_inputs - (num_forward_inputs + num_back_inputs) - in - (* As we forbid nested borrows, the additional inputs for the backward - * functions come from the borrows in the return value of the rust function: - * we thus use the name "ret" for those inputs *) - let backward_inputs = - List.map (fun ty -> (Some "ret", ty)) backward_inputs - in - let ctx, backward_inputs = - SymbolicToPure.fresh_vars backward_inputs ctx - in - (* The outputs for the backward functions, however, come from borrows - * present in the input values of the rust function: for those we reuse - * the names of the input values. *) - let backward_outputs = - List.combine backward_sg.output_names backward_sg.sg.doutputs - in - let ctx, backward_outputs = - SymbolicToPure.fresh_vars backward_outputs ctx - in - let backward_inputs = - RegionGroupId.Map.singleton back_id backward_inputs - in - let backward_outputs = - RegionGroupId.Map.singleton back_id backward_outputs - in - - (* Put everything in the context *) - let ctx = - { - ctx with - bid = Some back_id; - sg = backward_sg.sg; - backward_inputs; - backward_outputs; - } - in - + (* Initialize the context *) + let ctx = { ctx with bid = Some back_id } in (* Translate *) SymbolicToPure.translate_fun_decl ctx (Some symbolic) in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular fdef.def_id) - fun_context.regions_hierarchies + let pure_backwards = + if !Config.return_back_funs then [] + else List.map translate_backward regions_hierarchy in - let pure_backwards = List.map translate_backward regions_hierarchy in (* Return *) (pure_forward, pure_backwards) @@ -300,36 +305,10 @@ let translate_crate_to_pure (crate : crate) : (List.map (fun (def : Pure.type_decl) -> (def.def_id, def)) type_decls) in - (* Translate all the function *signatures* *) - let assumed_sigs = - List.map - (fun (info : Assumed.assumed_fun_info) -> - ( FAssumed info.fun_id, - List.map (fun _ -> None) info.fun_sig.inputs, - info.fun_sig )) - Assumed.assumed_fun_infos - in - let local_sigs = - List.map - (fun (fdef : fun_decl) -> - let input_names = - match fdef.body with - | None -> List.map (fun _ -> None) fdef.signature.inputs - | Some body -> - List.map - (fun (v : var) -> v.name) - (LlbcAstUtils.fun_body_get_input_vars body) - in - (FRegular fdef.def_id, input_names, fdef.signature)) - (FunDeclId.Map.values crate.fun_decls) - in - let sigs = List.append assumed_sigs local_sigs in - let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in - (* Translate all the *transparent* functions *) let pure_translations = List.map - (translate_function_to_pure trans_ctx fun_sigs type_decls_map) + (translate_function_to_pure trans_ctx type_decls_map) (FunDeclId.Map.values crate.fun_decls) in @@ -1036,7 +1015,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) : List.map (fun { fwd; _ } -> let fwd_f = - if fwd.f.Pure.signature.info.effect_info.is_rec then + if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then [ (fwd.f.def_id, None) ] else [] in |