summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Config.ml63
-rw-r--r--compiler/Extract.ml8
-rw-r--r--compiler/InterpreterExpressions.mli2
-rw-r--r--compiler/Pure.ml54
-rw-r--r--compiler/PureMicroPasses.ml91
-rw-r--r--compiler/PureUtils.ml23
-rw-r--r--compiler/SymbolicToPure.ml73
-rw-r--r--compiler/Translate.ml8
8 files changed, 246 insertions, 76 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index b09544ba..b8af6c6d 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -92,6 +92,69 @@ let loop_fixed_point_max_num_iters = 2
(** {1 Translation} *)
+(** If true, do not define separate forward/backward functions, but make the
+ forward functions return the backward function.
+
+ Example:
+ {[
+ (* Rust *)
+ pub fn list_nth<'a, T>(l: &'a mut List<T>, i: u32) -> &'a mut T {
+ match l {
+ List::Nil => {
+ panic!()
+ }
+ List::Cons(x, tl) => {
+ if i == 0 {
+ x
+ } else {
+ list_nth(tl, i - 1)
+ }
+ }
+ }
+ }
+
+ (* Translation, if return_back_funs = false *)
+ def list_nth (T : Type) (l : List T) (i : U32) : Result T :=
+ match l with
+ | List.Cons x tl =>
+ if i = 0#u32
+ then Result.ret x
+ else do
+ let i0 ← i - 1#u32
+ list_nth T tl i0
+ | List.Nil => Result.fail .panic
+
+ def list_nth_back
+ (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) :=
+ match l with
+ | List.Cons x tl =>
+ if i = 0#u32
+ then Result.ret (List.Cons ret tl)
+ else
+ do
+ let i0 ← i - 1#u32
+ let tl0 ← list_nth_back T tl i0 ret
+ Result.ret (List.Cons x tl0)
+ | List.Nil => Result.fail .panic
+
+ (* Translation, if return_back_funs = true *)
+ def list_nth (T: Type) (ls : List T) (i : U32) :
+ Result (T × (T → Result (List T))) :=
+ match ls with
+ | List.Cons x tl =>
+ if i = 0#u32
+ then Result.ret (x, (λ ret => return (ret :: ls)))
+ else do
+ let i0 ← i - 1#u32
+ let (x, back) ← list_nth ls i0
+ Return.ret (x,
+ (λ ret => do
+ let ls ← back ret
+ return (x :: ls)))
+ ]}
+ *)
+let return_back_funs = ref true
+
(** Forbids using field projectors for structures.
If we don't use field projectors, whenever we symbolically expand a structure
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 20cdb20b..93fcf416 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1469,8 +1469,10 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
*)
let inputs_lvs =
let all_inputs = (Option.get def.body).inputs_lvs in
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
let num_fwd_inputs =
- def.signature.info.num_fwd_inputs_with_fuel_with_state
+ def.signature.info.fwd_info.num_inputs_with_fuel_with_state
in
Collections.List.prefix num_fwd_inputs all_inputs
in
@@ -1515,8 +1517,10 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
if has_decreases_clause && !backend = Lean then (
let def_body = Option.get def.body in
let all_vars = List.map (fun (v : var) -> v.id) def_body.inputs in
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
let num_fwd_inputs =
- def.signature.info.num_fwd_inputs_with_fuel_with_state
+ def.signature.info.fwd_info.num_inputs_with_fuel_with_state
in
let vars = Collections.List.prefix num_fwd_inputs all_vars in
diff --git a/compiler/InterpreterExpressions.mli b/compiler/InterpreterExpressions.mli
index f8d979f4..b975371c 100644
--- a/compiler/InterpreterExpressions.mli
+++ b/compiler/InterpreterExpressions.mli
@@ -52,7 +52,7 @@ val eval_operands :
Transmits the computed rvalue to the received continuation.
- Note that this function fails on {!constructor:Aeneas.Expressions.rvalue.Discriminant}: discriminant
+ Note that this function fails on {!Aeneas.Expressions.rvalue.Discriminant}: discriminant
reads should have been eliminated from the AST.
*)
val eval_rvalue_not_global :
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 8d39cc69..c3716001 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -561,7 +561,7 @@ type fun_id_or_trait_method_ref =
(** A function id for a non-assumed function *)
type regular_fun_id =
- fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option
+ fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option
[@@deriving show, ord]
(** A function identifier *)
@@ -860,8 +860,8 @@ type fun_effect_info = {
the set [{ forward function } U { backward functions }].
We need this because of the option {!val:Config.backward_no_state_update}:
- if it is [true], then in case of a backward function {!stateful} is [false],
- but we might need to know whether the corresponding forward function
+ if it is [true], then in case of a backward function {!stateful} might be
+ [false], but we might need to know whether the corresponding forward function
is stateful or not.
*)
stateful : bool; (** [true] if the function is stateful (updates a state) *)
@@ -873,21 +873,41 @@ type fun_effect_info = {
}
[@@deriving show]
-(** Meta information about a function signature *)
-type fun_sig_info = {
+type inputs_info = {
has_fuel : bool;
- (* TODO: add [num_fwd_inputs_no_fuel_no_state] *)
- num_fwd_inputs_with_fuel_no_state : int;
- (** The number of input types for forward computation, with the fuel (if used)
+ num_inputs_no_fuel_no_state : int;
+ (** The number of input types ignoring the fuel (if used)
and ignoring the state (if used) *)
- num_fwd_inputs_with_fuel_with_state : int;
- (** The number of input types for forward computation, with fuel and state (if used) *)
- num_back_inputs_no_state : int option;
- (** The number of additional inputs for the backward computation (if pertinent),
- ignoring the state (if there is one) *)
- num_back_inputs_with_state : int option;
- (** The number of additional inputs for the backward computation (if pertinent),
- with the state (if there is one) *)
+ num_inputs_with_fuel_no_state : int;
+ (** The number of input types, with the fuel (if used)
+ and ignoring the state (if used) *)
+ num_inputs_with_fuel_with_state : int;
+ (** The number of input types, with fuel and state (if used) *)
+}
+[@@deriving show]
+
+type 'a back_info =
+ | SingleBack of 'a option
+ (** Information about a single backward function, if pertinent.
+
+ We use this variant if we split the forward and the backward functions.
+ *)
+ | AllBacks of 'a RegionGroupId.Map.t
+ (** Information about the various backward functions.
+
+ We use this if we *do not* split the forward and the backward functions.
+ All the information is then carried by the forward function.
+ *)
+[@@deriving show]
+
+type back_inputs_info = inputs_info back_info [@@deriving show]
+
+(** Meta information about a function signature *)
+type fun_sig_info = {
+ fwd_info : inputs_info;
+ (** Information about the inputs of the forward function *)
+ back_info : back_inputs_info;
+ (** Information about the inputs of the backward functions. *)
effect_info : fun_effect_info;
}
[@@deriving show]
@@ -1020,7 +1040,7 @@ type fun_decl = {
*)
loop_id : LoopId.id option;
(** [Some] if this definition was generated for a loop *)
- back_id : T.RegionGroupId.id option;
+ back_id : RegionGroupId.id option;
llbc_name : llbc_name; (** The original LLBC name. *)
name : string;
(** We use the name only for printing purposes (for debugging):
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 959ec1c8..7f122f15 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1326,6 +1326,9 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let fun_sig_info = fun_sig.info in
let fun_effect_info = fun_sig_info.effect_info in
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
+
(* Generate the loop definition *)
let loop_effect_info =
{
@@ -1340,36 +1343,44 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_sig_info =
let fuel = if !Config.use_fuel then 1 else 0 in
let num_inputs = List.length loop.inputs in
- let num_fwd_inputs_with_fuel_no_state = fuel + num_inputs in
- let fwd_state =
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
- - fun_sig_info.num_fwd_inputs_with_fuel_no_state
- in
- let num_fwd_inputs_with_fuel_with_state =
- num_fwd_inputs_with_fuel_no_state + fwd_state
+ let fwd_info : inputs_info =
+ let info = fun_sig_info.fwd_info in
+ let fwd_state =
+ info.num_inputs_with_fuel_with_state
+ - info.num_inputs_with_fuel_no_state
+ in
+ {
+ has_fuel = !Config.use_fuel;
+ num_inputs_no_fuel_no_state = num_inputs;
+ num_inputs_with_fuel_no_state = num_inputs + fuel;
+ num_inputs_with_fuel_with_state =
+ num_inputs + fuel + fwd_state;
+ }
in
+
{
- has_fuel = !Config.use_fuel;
- num_fwd_inputs_with_fuel_no_state;
- num_fwd_inputs_with_fuel_with_state;
- num_back_inputs_no_state = fun_sig_info.num_back_inputs_no_state;
- num_back_inputs_with_state =
- fun_sig_info.num_back_inputs_with_state;
+ fwd_info;
+ back_info = fun_sig_info.back_info;
effect_info = loop_effect_info;
}
in
+ assert (fun_sig_info_is_wf loop_sig_info);
let inputs_tys =
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
+
let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in
let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in
+ let info = fun_sig_info.fwd_info in
let state =
Collections.List.subslice fun_sig.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_no_state
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ info.num_inputs_with_fuel_no_state
+ info.num_inputs_with_fuel_with_state
in
let _, back_inputs =
Collections.List.split_at fun_sig.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ info.num_inputs_with_fuel_with_state
in
List.concat [ fuel; fwd_inputs; state; back_inputs ]
in
@@ -1430,14 +1441,17 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
in
(* Introduce the additional backward inputs *)
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
let fun_body = Option.get def.body in
+ let info = fun_sig_info.fwd_info in
let _, back_inputs =
Collections.List.split_at fun_body.inputs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ info.num_inputs_with_fuel_with_state
in
let _, back_inputs_lvs =
Collections.List.split_at fun_body.inputs_lvs
- fun_sig_info.num_fwd_inputs_with_fuel_with_state
+ info.num_inputs_with_fuel_with_state
in
let inputs =
@@ -2053,12 +2067,14 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
(* We start by computing the filtering information, for each function *)
let compute_one_filter_info (decl : fun_decl) =
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
(* There should be a body *)
let body = Option.get decl.body in
(* We only look at the forward inputs, without the state *)
let inputs_prefix, _ =
Collections.List.split_at body.inputs
- decl.signature.info.num_fwd_inputs_with_fuel_no_state
+ decl.signature.info.fwd_info.num_inputs_with_fuel_no_state
in
let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in
let inputs_prefix_length = List.length inputs_prefix in
@@ -2078,7 +2094,10 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
(* Set the fuel as used *)
let sg_info = decl.signature.info in
- if sg_info.has_fuel then set_used (fst (Collections.List.nth inputs 0));
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
+ if sg_info.fwd_info.has_fuel then
+ set_used (fst (Collections.List.nth inputs 0));
let visitor =
object (self : 'self)
@@ -2166,31 +2185,35 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
=
decl.signature
in
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
+ let { fwd_info; back_info; effect_info } = info in
+
let {
has_fuel;
- num_fwd_inputs_with_fuel_no_state;
- num_fwd_inputs_with_fuel_with_state;
- num_back_inputs_no_state;
- num_back_inputs_with_state;
- effect_info;
+ num_inputs_no_fuel_no_state;
+ num_inputs_with_fuel_no_state;
+ num_inputs_with_fuel_with_state;
} =
- info
+ fwd_info
in
let inputs = filter_prefix used_info inputs in
- let info =
+ let fwd_info =
{
has_fuel;
- num_fwd_inputs_with_fuel_no_state =
- num_fwd_inputs_with_fuel_no_state - num_filtered;
- num_fwd_inputs_with_fuel_with_state =
- num_fwd_inputs_with_fuel_with_state - num_filtered;
- num_back_inputs_no_state;
- num_back_inputs_with_state;
- effect_info;
+ num_inputs_no_fuel_no_state =
+ num_inputs_no_fuel_no_state - num_filtered;
+ num_inputs_with_fuel_no_state =
+ num_inputs_with_fuel_no_state - num_filtered;
+ num_inputs_with_fuel_with_state =
+ num_inputs_with_fuel_with_state - num_filtered;
}
in
+
+ let info = { fwd_info; back_info; effect_info } in
+ assert (fun_sig_info_is_wf info);
let signature =
{ generics; llbc_generics; preds; inputs; output; doutputs; info }
in
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 39dcd52d..3c038149 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -57,6 +57,29 @@ end
module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType)
module FunLoopIdSet = Collections.MakeSet (FunLoopIdOrderedType)
+let inputs_info_is_wf (info : inputs_info) : bool =
+ let {
+ has_fuel;
+ num_inputs_no_fuel_no_state;
+ num_inputs_with_fuel_no_state;
+ num_inputs_with_fuel_with_state;
+ } =
+ info
+ in
+ let fuel = if has_fuel then 1 else 0 in
+ num_inputs_no_fuel_no_state >= 0
+ && num_inputs_with_fuel_no_state = num_inputs_no_fuel_no_state + fuel
+ && num_inputs_with_fuel_with_state >= num_inputs_with_fuel_no_state
+
+let fun_sig_info_is_wf (info : fun_sig_info) : bool =
+ inputs_info_is_wf info.fwd_info
+ &&
+ match info.back_info with
+ | SingleBack None -> true
+ | SingleBack (Some info) -> inputs_info_is_wf info
+ | AllBacks infos ->
+ List.for_all inputs_info_is_wf (RegionGroupId.Map.values infos)
+
let dest_arrow_ty (ty : ty) : ty * ty =
match ty with
| TArrow (arg_ty, ret_ty) -> (arg_ty, ret_ty)
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 84f09280..1fd4896e 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -855,10 +855,14 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
name (outputs for backward functions come from borrows in the inputs
of the forward function) which we use as hints to generate pretty names
in the extracted code.
+
+ We use [bid] ("backward function id") only if we split the forward
+ and the backward functions.
*)
let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
(sg : A.fun_sig) (input_names : string option list)
(bid : T.RegionGroupId.id option) : fun_sig_named_outputs =
+ assert (Option.is_none bid || not !Config.return_back_funs);
let fun_infos = decls_ctx.fun_ctx.fun_infos in
let type_infos = decls_ctx.type_ctx.type_infos in
(* Retrieve the list of parent backward functions *)
@@ -939,6 +943,18 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
let inside_mut = false in
translate_back_ty type_infos keep_region inside_mut ty
in
+ let translate_back_inputs_for_gid gid : ty list =
+ (* For now, we don't allow nested borrows, so the additional inputs to the
+ backward function can only come from borrows that were returned like
+ in (for the backward function we introduce for 'a):
+ {[
+ fn f<'a>(...) -> &'a mut u32;
+ ]}
+ Upon ending the abstraction for 'a, we need to get back the borrow
+ the function returned.
+ *)
+ List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
+ in
(* Compute the additinal inputs for the current function, if it is a backward
* function *)
let back_inputs =
@@ -1034,32 +1050,47 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
(* Generic parameters *)
let generics = translate_generic_params sg.generics in
(* Return *)
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
let has_fuel = fuel <> [] in
- let num_fwd_inputs_no_state = List.length fwd_inputs in
+ let num_fwd_inputs_no_fuel_no_state = List.length fwd_inputs in
let num_fwd_inputs_with_fuel_no_state =
(* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
- List.length fuel + num_fwd_inputs_no_state
+ List.length fuel + num_fwd_inputs_no_fuel_no_state
in
let num_back_inputs_no_state =
if bid = None then None else Some (List.length back_inputs)
in
- let info =
+ let fwd_info : inputs_info =
{
has_fuel;
- num_fwd_inputs_with_fuel_no_state;
- num_fwd_inputs_with_fuel_with_state =
+ num_inputs_no_fuel_no_state = num_fwd_inputs_no_fuel_no_state;
+ num_inputs_with_fuel_no_state = num_fwd_inputs_with_fuel_no_state;
+ num_inputs_with_fuel_with_state =
(* We use the fact that [fwd_state_ty] has length 1 if there is a state,
and 0 otherwise *)
num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty;
- num_back_inputs_no_state;
- num_back_inputs_with_state =
- (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *)
- Option.map
- (fun n -> n + List.length back_state_ty)
- num_back_inputs_no_state;
- effect_info;
}
in
+ let back_info : back_inputs_info =
+ if !Config.return_back_funs then
+ SingleBack
+ (Option.map
+ (fun n ->
+ (* Note that backward functions never use fuel *)
+ {
+ has_fuel = false;
+ num_inputs_no_fuel_no_state = n;
+ num_inputs_with_fuel_no_state = n;
+ (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *)
+ num_inputs_with_fuel_with_state = n + List.length back_state_ty;
+ })
+ num_back_inputs_no_state)
+ else (* Create the map *)
+ failwith "TODO"
+ in
+ let info = { fwd_info; back_info; effect_info } in
+ assert (fun_sig_info_is_wf info);
let preds = translate_predicates sg.preds in
let sg =
{
@@ -3151,14 +3182,16 @@ let translate_fun_signatures (decls_ctx : C.decls_ctx)
let fwd_id = (fun_id, None) in
(* The backward functions *)
let back_sgs =
- List.map
- (fun (rg : T.region_var_group) ->
- let tsg =
- translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id)
- in
- let id = (fun_id, Some rg.id) in
- (id, tsg))
- regions_hierarchy
+ if !Config.return_back_funs then []
+ else
+ List.map
+ (fun (rg : T.region_var_group) ->
+ let tsg =
+ translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id)
+ in
+ let id = (fun_id, Some rg.id) in
+ (id, tsg))
+ regions_hierarchy
in
(* Return *)
(fwd_id, fwd_sg) :: back_sgs
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 221d4e73..54e24066 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -216,11 +216,15 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(* We need to ignore the forward inputs, and the state input (if there is) *)
let backward_inputs =
let sg = backward_sg.sg in
+ (* TODO: *)
+ assert (not !Config.return_back_funs);
(* We need to ignore the forward state and the backward state *)
let num_forward_inputs =
- sg.info.num_fwd_inputs_with_fuel_with_state
+ sg.info.fwd_info.num_inputs_with_fuel_with_state
+ in
+ let num_back_inputs =
+ (Option.get sg.info.back_info).num_inputs_no_fuel_no_state
in
- let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in
Collections.List.subslice sg.inputs num_forward_inputs
(num_forward_inputs + num_back_inputs)
in