summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-15 18:54:06 +0100
committerSon Ho2023-12-15 18:54:06 +0100
commit884edaa3ee975626f184249d491f343fc02a66e2 (patch)
tree38aaf96746d6f61de2ef461a2a0add56db561083
parent955fdab55304979ba2d61432ea654241f20abaa4 (diff)
Make progress on updating the code
-rw-r--r--compiler/PureMicroPasses.ml48
-rw-r--r--compiler/SymbolicToPure.ml79
-rw-r--r--compiler/Translate.ml207
3 files changed, 134 insertions, 200 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 0102b13e..a7c2f154 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -776,9 +776,11 @@ let inline_useless_var_reassignments (ctx : trans_ctx) (inline_named : bool)
in
{ def with body = Some body }
-(** Given a forward or backward function call, is there, for every execution
+(** For the cases where we split the forward/backward functions.
+
+ Given a forward or backward function call, is there, for every execution
path, a child backward function called later with exactly the same input
- list prefix? We use this to filter useless function calls: if there are
+ list prefix. We use this to filter useless function calls: if there are
such child calls, we can remove this one (in case its outputs are not
used).
We do this check because we can't simply remove function calls whose
@@ -1008,17 +1010,21 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
* under some conditions. *)
match (filter_monadic_calls, opt_destruct_function_call re) with
| true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) ->
- (* We need to check if there is a child call - see
- * the comments for:
- * [expression_contains_child_call_in_all_paths] *)
- let has_child_call =
- expression_contains_child_call_in_all_paths ctx fid lp_id
- rg_id tys args e
- in
- if has_child_call then (* Filter *)
- (e.e, fun _ -> used)
- else (* No child call: don't filter *)
- dont_filter ()
+ (* If we split the forward/backward functions.
+
+ We need to check if there is a child call - see
+ the comments for:
+ [expression_contains_child_call_in_all_paths] *)
+ if not !Config.return_back_funs then
+ let has_child_call =
+ expression_contains_child_call_in_all_paths ctx fid
+ lp_id rg_id tys args e
+ in
+ if has_child_call then (* Filter *)
+ (e.e, fun _ -> used)
+ else (* No child call: don't filter *)
+ dont_filter ()
+ else dont_filter ()
| _ ->
(* Not an LLBC function call or not allowed to filter: we can't filter *)
dont_filter ()
@@ -1509,9 +1515,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
altogether.
*)
let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
- (* Note that at this point, the output types are no longer seen as tuples:
- * they should be lists of length 1. *)
- if
+ (* The question of filtering the forward functions arises only if we split
+ the forward/backward functions *)
+ if !Config.return_back_funs then true
+ else if
+ (* Note that at this point, the output types are no longer seen as tuples:
+ * they should be lists of length 1. *)
!Config.filter_useless_functions
&& fwd.f.signature.output = mk_result_ty mk_unit_ty
&& backs <> []
@@ -1957,9 +1966,10 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
(* Remove the backward functions with no outputs.
- * Note that the calls to those functions should already have been removed,
- * when translating from symbolic to pure. Here, we remove the definitions
- * altogether, because they are now useless *)
+
+ Note that the *calls* to those functions should already have been removed,
+ when translating from symbolic to pure. Here, we remove the definitions
+ altogether, because they are now useless *)
let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in
let opt_def = filter_if_backward_with_no_outputs def in
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 08f9e950..204fc399 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -45,7 +45,6 @@ type fun_sig_named_outputs = {
type fun_context = {
llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t;
- fun_sigs : fun_sig_named_outputs RegularFunIdNotLoopMap.t; (** *)
fun_infos : fun_info A.FunDeclId.Map.t;
regions_hierarchies : T.region_var_groups FunIdMap.t;
}
@@ -144,7 +143,11 @@ type bs_ctx = {
a symbolic expansion or upon ending an abstraction, for instance)
we introduce a new variable (with a let-binding).
*)
- var_counter : VarId.generator;
+ var_counter : VarId.generator ref;
+ (** Using a ref to make sure all the variables identifiers are unique.
+ TODO: this is not very clean, and the code was initially written without
+ a reference (and it's shape hasn't changed). We should use DeBruijn indices.
+ *)
state_var : VarId.id;
(** The current state variable, in case the function is stateful *)
back_state_vars : VarId.id RegionGroupId.Map.t;
@@ -1131,13 +1134,14 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern =
(* Generate the fresh variable *)
- let id, var_counter = VarId.fresh ctx.var_counter in
+ let id, var_counter = VarId.fresh !(ctx.var_counter) in
let state_var =
{ id; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
in
let state_pat = mk_typed_pattern_from_var state_var None in
(* Update the context *)
- let ctx = { ctx with var_counter; state_var = id } in
+ ctx.var_counter := var_counter;
+ let ctx = { ctx with state_var = id } in
(* Return *)
(ctx, state_var, state_pat)
@@ -1146,11 +1150,11 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern =
let fresh_var_llbc_ty (basename : string option) (ty : T.ty) (ctx : bs_ctx) :
bs_ctx * var =
(* Generate the fresh variable *)
- let id, var_counter = VarId.fresh ctx.var_counter in
+ let id, var_counter = VarId.fresh !(ctx.var_counter) in
let ty = ctx_translate_fwd_ty ctx ty in
let var = { id; basename; ty } in
(* Update the context *)
- let ctx = { ctx with var_counter } in
+ ctx.var_counter := var_counter;
(* Return *)
(ctx, var)
@@ -1184,10 +1188,10 @@ let fresh_named_vars_for_symbolic_values
let fresh_var (basename : string option) (ty : ty) (ctx : bs_ctx) : bs_ctx * var
=
(* Generate the fresh variable *)
- let id, var_counter = VarId.fresh ctx.var_counter in
+ let id, var_counter = VarId.fresh !(ctx.var_counter) in
let var = { id; basename; ty } in
(* Update the context *)
- let ctx = { ctx with var_counter } in
+ ctx.var_counter := var_counter;
(* Return *)
(ctx, var)
@@ -3303,65 +3307,6 @@ let translate_type_decls (ctx : Contexts.decls_ctx) : type_decl list =
List.map (translate_type_decl ctx)
(TypeDeclId.Map.values ctx.type_ctx.type_decls)
-(** Translates function signatures.
-
- Takes as input a list of function information containing:
- - the function id
- - a list of optional names for the inputs
- - the function signature
-
- Returns a map from forward/backward functions identifiers to:
- - translated function signatures
- - optional names for the outputs values (we derive them for the backward
- functions)
- *)
-let translate_fun_signatures (decls_ctx : C.decls_ctx)
- (functions : (A.fun_id * string option list * A.fun_sig) list) :
- fun_sig_named_outputs RegularFunIdNotLoopMap.t =
- (* For every function, translate the signatures of:
- - the forward function
- - the backward functions
- *)
- let translate_one (fun_id : A.fun_id) (input_names : string option list)
- (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list
- =
- log#ldebug
- (lazy
- ("Translating signature of function: "
- ^ Print.Expressions.fun_id_to_string
- (Print.Contexts.decls_ctx_to_fmt_env decls_ctx)
- fun_id));
- (* Retrieve the regions hierarchy *)
- let regions_hierarchy =
- FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies
- in
- (* The forward function *)
- let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in
- let fwd_id = (fun_id, None) in
- (* The backward functions *)
- let back_sgs =
- if !Config.return_back_funs then []
- else
- List.map
- (fun (rg : T.region_var_group) ->
- let tsg =
- translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id)
- in
- let id = (fun_id, Some rg.id) in
- (id, tsg))
- regions_hierarchy
- in
- (* Return *)
- (fwd_id, fwd_sg) :: back_sgs
- in
- let translated =
- List.concat
- (List.map (fun (id, names, sg) -> translate_one id names sg) functions)
- in
- List.fold_left
- (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m)
- RegularFunIdNotLoopMap.empty translated
-
let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl)
: trait_decl =
let {
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 06d4bd6d..8b221c93 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -6,7 +6,6 @@ open LlbcAst
open Contexts
module SA = SymbolicAst
module Micro = PureMicroPasses
-open PureUtils
open TranslateCore
(** The local logger *)
@@ -43,7 +42,6 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : fun_decl) :
TODO: maybe we should introduce a record for this.
*)
let translate_function_to_pure (trans_ctx : trans_ctx)
- (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdNotLoopMap.t)
(pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : fun_decl) :
pure_fun_translation_no_loops =
(* Debug *)
@@ -58,13 +56,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(* Convert the symbolic ASTs to pure ASTs: *)
(* Initialize the context *)
- let forward_sig =
- RegularFunIdNotLoopMap.find (FRegular def_id, None) fun_sigs
- in
let sv_to_var = SymbolicValueId.Map.empty in
let var_counter = Pure.VarId.generator_zero in
let state_var, var_counter = Pure.VarId.fresh var_counter in
- let back_state_var, var_counter = Pure.VarId.fresh var_counter in
let fuel0, var_counter = Pure.VarId.fresh var_counter in
let fuel, var_counter = Pure.VarId.fresh var_counter in
let calls = FunCallId.Map.empty in
@@ -89,7 +83,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
let fun_context =
{
SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls;
- fun_sigs;
fun_infos = trans_ctx.fun_ctx.fun_infos;
regions_hierarchies = trans_ctx.fun_ctx.regions_hierarchies;
}
@@ -126,17 +119,45 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
!m
in
+ let input_names =
+ match fdef.body with
+ | None -> List.map (fun _ -> None) fdef.signature.inputs
+ | Some body ->
+ List.map
+ (fun (v : var) -> v.name)
+ (LlbcAstUtils.fun_body_get_input_vars body)
+ in
+
+ let sg =
+ SymbolicToPure.translate_fun_sig_to_decomposed trans_ctx (FRegular def_id)
+ fdef.signature input_names
+ in
+
+ let regions_hierarchy =
+ LlbcAstUtils.FunIdMap.find (FRegular def_id) fun_context.regions_hierarchies
+ in
+
+ let var_counter, back_state_vars =
+ if !Config.return_back_funs then (var_counter, [])
+ else
+ List.fold_left_map
+ (fun var_counter (region_vars : region_var_group) ->
+ let gid = region_vars.id in
+ let var, var_counter = Pure.VarId.fresh var_counter in
+ (var_counter, (gid, var)))
+ var_counter regions_hierarchy
+ in
+ let back_state_vars = RegionGroupId.Map.of_list back_state_vars in
+
let ctx =
{
SymbolicToPure.bid = None;
- (* Dummy for now *)
- sg = forward_sig.sg;
- fwd_sg = forward_sig.sg;
+ sg;
(* Will need to be updated for the backward functions *)
sv_to_var;
- var_counter;
+ var_counter = ref var_counter;
state_var;
- back_state_var;
+ back_state_vars;
fuel0;
fuel;
type_context;
@@ -146,9 +167,11 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls;
fun_decl = fdef;
forward_inputs = [];
- (* Empty for now *)
- backward_inputs = RegionGroupId.Map.empty;
- (* Empty for now *)
+ (* Initialized just below *)
+ backward_inputs_no_state = RegionGroupId.Map.empty;
+ (* Initialized just below *)
+ backward_inputs_with_state = RegionGroupId.Map.empty;
+ (* Initialized just below *)
backward_outputs = RegionGroupId.Map.empty;
loop_backward_outputs = None;
(* Empty for now *)
@@ -180,6 +203,51 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
| _ -> raise (Failure "Unreachable")
in
+ (* Add the backward inputs *)
+ let ctx, backward_inputs_no_state, backward_inputs_with_state =
+ if !Config.return_back_funs then (ctx, [], [])
+ else
+ let ctx, inputs_no_with_state =
+ List.fold_left_map
+ (fun ctx (region_vars : region_var_group) ->
+ let gid = region_vars.id in
+ let back_sg = RegionGroupId.Map.find gid sg.back_sg in
+ let ctx, no_state =
+ SymbolicToPure.fresh_vars back_sg.inputs_no_state ctx
+ in
+ let ctx, with_state =
+ SymbolicToPure.fresh_vars back_sg.inputs ctx
+ in
+ (ctx, ((gid, no_state), (gid, with_state))))
+ ctx regions_hierarchy
+ in
+ let inputs_no_state, inputs_with_state =
+ List.split inputs_no_with_state
+ in
+ (ctx, inputs_no_state, inputs_with_state)
+ in
+ let backward_inputs_no_state =
+ RegionGroupId.Map.of_list backward_inputs_no_state
+ in
+ let backward_inputs_with_state =
+ RegionGroupId.Map.of_list backward_inputs_with_state
+ in
+ let ctx = { ctx with backward_inputs_no_state; backward_inputs_with_state } in
+
+ (* Add the backward outputs *)
+ let ctx, backward_outputs =
+ List.fold_left_map
+ (fun ctx (region_vars : region_var_group) ->
+ let gid = region_vars.id in
+ let back_sg = RegionGroupId.Map.find gid sg.back_sg in
+ let outputs = List.combine back_sg.output_names back_sg.outputs in
+ let ctx, vars = SymbolicToPure.fresh_vars outputs ctx in
+ (ctx, (gid, vars)))
+ ctx regions_hierarchy
+ in
+ let backward_outputs = RegionGroupId.Map.of_list backward_outputs in
+ let ctx = { ctx with backward_outputs } in
+
(* Translate the forward function *)
let pure_forward =
match symbolic_trans with
@@ -187,7 +255,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
| Some (_, ast) -> SymbolicToPure.translate_fun_decl ctx (Some ast)
in
- (* Translate the backward functions *)
+ (* Translate the backward functions, if we split the forward and backward functions *)
let translate_backward (rg : region_var_group) : Pure.fun_decl =
(* For the backward inputs/outputs initialization: we use the fact that
* there are no nested borrows for now, and so that the region groups
@@ -197,83 +265,20 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
match symbolic_trans with
| None ->
- (* Initialize the context - note that the ret_ty is not really
- * useful as we don't translate a body *)
- let backward_sg =
- RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs
- in
- let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in
-
+ (* Initialize the context *)
+ let ctx = { ctx with bid = Some back_id } in
(* Translate *)
SymbolicToPure.translate_fun_decl ctx None
| Some (_, symbolic) ->
- (* Finish initializing the context by adding the additional input
- variables required by the backward function.
- *)
- let backward_sg =
- RegularFunIdNotLoopMap.find (FRegular def_id, Some back_id) fun_sigs
- in
- (* We need to ignore the forward inputs, and the state input (if there is) *)
- let backward_inputs =
- let sg = backward_sg.sg in
- (* TODO: *)
- assert (not !Config.return_back_funs);
- (* We need to ignore the forward state and the backward state *)
- let num_forward_inputs =
- sg.info.fwd_info.num_inputs_with_fuel_with_state
- in
- let num_back_inputs =
- match sg.info.back_info with
- | SingleBack (Some info) -> info.num_inputs_no_fuel_no_state
- | _ -> raise (Failure "Unexpected")
- in
- Collections.List.subslice sg.inputs num_forward_inputs
- (num_forward_inputs + num_back_inputs)
- in
- (* As we forbid nested borrows, the additional inputs for the backward
- * functions come from the borrows in the return value of the rust function:
- * we thus use the name "ret" for those inputs *)
- let backward_inputs =
- List.map (fun ty -> (Some "ret", ty)) backward_inputs
- in
- let ctx, backward_inputs =
- SymbolicToPure.fresh_vars backward_inputs ctx
- in
- (* The outputs for the backward functions, however, come from borrows
- * present in the input values of the rust function: for those we reuse
- * the names of the input values. *)
- let backward_outputs =
- List.combine backward_sg.output_names backward_sg.sg.doutputs
- in
- let ctx, backward_outputs =
- SymbolicToPure.fresh_vars backward_outputs ctx
- in
- let backward_inputs =
- RegionGroupId.Map.singleton back_id backward_inputs
- in
- let backward_outputs =
- RegionGroupId.Map.singleton back_id backward_outputs
- in
-
- (* Put everything in the context *)
- let ctx =
- {
- ctx with
- bid = Some back_id;
- sg = backward_sg.sg;
- backward_inputs;
- backward_outputs;
- }
- in
-
+ (* Initialize the context *)
+ let ctx = { ctx with bid = Some back_id } in
(* Translate *)
SymbolicToPure.translate_fun_decl ctx (Some symbolic)
in
- let regions_hierarchy =
- LlbcAstUtils.FunIdMap.find (FRegular fdef.def_id)
- fun_context.regions_hierarchies
+ let pure_backwards =
+ if !Config.return_back_funs then []
+ else List.map translate_backward regions_hierarchy
in
- let pure_backwards = List.map translate_backward regions_hierarchy in
(* Return *)
(pure_forward, pure_backwards)
@@ -300,36 +305,10 @@ let translate_crate_to_pure (crate : crate) :
(List.map (fun (def : Pure.type_decl) -> (def.def_id, def)) type_decls)
in
- (* Translate all the function *signatures* *)
- let assumed_sigs =
- List.map
- (fun (info : Assumed.assumed_fun_info) ->
- ( FAssumed info.fun_id,
- List.map (fun _ -> None) info.fun_sig.inputs,
- info.fun_sig ))
- Assumed.assumed_fun_infos
- in
- let local_sigs =
- List.map
- (fun (fdef : fun_decl) ->
- let input_names =
- match fdef.body with
- | None -> List.map (fun _ -> None) fdef.signature.inputs
- | Some body ->
- List.map
- (fun (v : var) -> v.name)
- (LlbcAstUtils.fun_body_get_input_vars body)
- in
- (FRegular fdef.def_id, input_names, fdef.signature))
- (FunDeclId.Map.values crate.fun_decls)
- in
- let sigs = List.append assumed_sigs local_sigs in
- let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in
-
(* Translate all the *transparent* functions *)
let pure_translations =
List.map
- (translate_function_to_pure trans_ctx fun_sigs type_decls_map)
+ (translate_function_to_pure trans_ctx type_decls_map)
(FunDeclId.Map.values crate.fun_decls)
in
@@ -1036,7 +1015,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : crate) :
List.map
(fun { fwd; _ } ->
let fwd_f =
- if fwd.f.Pure.signature.info.effect_info.is_rec then
+ if fwd.f.Pure.signature.fwd_info.effect_info.is_rec then
[ (fwd.f.def_id, None) ]
else []
in