summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-21 19:23:29 +0100
committerSon Ho2023-12-21 19:23:29 +0100
commit2f681446b11739e650b1d6050b717da872be9022 (patch)
tree475ca390fb80d65735590e1be600239b597e1528
parentccfcadc3686e69c1b8a8c826ec14f3c0e1dfbd7b (diff)
Simplify the type of the merged fwd/back functions
-rw-r--r--compiler/Config.ml26
-rw-r--r--compiler/Pure.ml6
-rw-r--r--compiler/PureMicroPasses.ml7
-rw-r--r--compiler/PureUtils.ml1
-rw-r--r--compiler/SymbolicToPure.ml159
5 files changed, 153 insertions, 46 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index b8af6c6d..2bb1ca34 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -370,6 +370,32 @@ let filter_useless_monadic_calls = ref true
*)
let filter_useless_functions = ref true
+(** Simplify the forward/backward functions, in case we merge them
+ (i.e., the forward functions return the backward functions).
+
+ The simplification occurs as follows:
+ - if a forward function returns the unit type and has non-trivial backward
+ functions, then we remove the returned output.
+ - if a backward function doesn't have inputs, we evaluate it inside the
+ forward function and don't wrap it in a result.
+
+ Example:
+ {[
+ // LLBC:
+ fn incr(x: &mut u32) { *x += 1 }
+
+ // Translation without simplification:
+ let incr (x : u32) : result (unit * result u32) = ...
+ ^^^^ ^^^^^^
+ | remove this result
+ remove the unit
+
+ // Translation with simplification:
+ let incr (x : u32) : result u32 = ...
+ ]}
+ *)
+let simplify_merged_fwd_backs = ref true
+
(** Use short names for the record fields.
Some backends can't disambiguate records when their field names have collisions.
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index ddacf0c4..05cdbd70 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -908,6 +908,11 @@ type fun_sig_info = {
fwd_info : inputs_info;
(** Information about the inputs of the forward function *)
effect_info : fun_effect_info;
+ ignore_output : bool;
+ (** In case we merge the forward/backward functions: should we ignore
+ the output (happens for forward functions if the output type is
+ [unit] and there are non-filtered backward functions)?
+ *)
}
[@@deriving show]
@@ -939,6 +944,7 @@ type back_sg_info = {
We derive those from the names of the inputs of the original LLBC
function. *)
effect_info : fun_effect_info;
+ filter : bool; (** Should we filter this backward function? *)
}
[@@deriving show]
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 63436e7d..16bf1c08 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1336,6 +1336,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
let fun_sig = def.signature in
let fwd_info = fun_sig.fwd_info in
let fwd_effect_info = fwd_info.effect_info in
+ let ignore_output = fwd_info.ignore_output in
(* Generate the loop definition *)
let loop_fwd_effect_info = fwd_effect_info in
@@ -1358,7 +1359,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
}
in
- { fwd_info; effect_info = loop_fwd_effect_info }
+ { fwd_info; effect_info = loop_fwd_effect_info; ignore_output }
in
assert (fun_sig_info_is_wf loop_fwd_sig_info);
@@ -2187,7 +2188,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
} =
decl.signature
in
- let { fwd_info; effect_info } = fwd_info in
+ let { fwd_info; effect_info; ignore_output } = fwd_info in
let {
has_fuel;
@@ -2212,7 +2213,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
}
in
- let fwd_info = { fwd_info; effect_info } in
+ let fwd_info = { fwd_info; effect_info; ignore_output } in
assert (fun_sig_info_is_wf fwd_info);
let signature =
{
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index d4aaba16..78d0b120 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -448,6 +448,7 @@ let mk_simpl_tuple_ty (tys : ty list) : ty =
let mk_bool_ty : ty = TLiteral TBool
let mk_unit_ty : ty = TAdt (TTuple, empty_generic_args)
+let ty_is_unit ty : bool = ty = mk_unit_ty
let mk_unit_rvalue : texpression =
let id = AdtCons { adt_id = TTuple; variant_id = None } in
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index f37ea201..70a4e18d 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -979,30 +979,6 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
in
(* Compute the backward output, without the effect information *)
let fwd_output = translate_fwd_ty type_infos sg.output in
- (* The additinoal information *)
- let fwd_info =
- (* *)
- let has_fuel = fwd_fuel <> [] in
- let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in
- let num_inputs_with_fuel_no_state =
- (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
- List.length fwd_fuel + num_inputs_no_fuel_no_state
- in
- let fwd_info : inputs_info =
- {
- has_fuel;
- num_inputs_no_fuel_no_state;
- num_inputs_with_fuel_no_state;
- num_inputs_with_fuel_with_state =
- (* We use the fact that [fwd_state_ty] has length 1 if there is a state,
- and 0 otherwise *)
- num_inputs_with_fuel_no_state + List.length fwd_state_ty;
- }
- in
- let info = { fwd_info; effect_info = fwd_effect_info } in
- assert (fun_sig_info_is_wf info);
- info
- in
(* Compute the type information for the backward function *)
(* Small helper to translate types for backward functions *)
@@ -1086,6 +1062,9 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
in
let inputs = inputs_no_state @ state in
let output_names, outputs = compute_back_outputs_for_gid gid in
+ let filter =
+ !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+ in
let info =
{
inputs;
@@ -1093,6 +1072,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
outputs;
output_names;
effect_info = back_effect_info;
+ filter;
}
in
(gid, info)
@@ -1102,6 +1082,39 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed
(List.map compute_back_info_for_group regions_hierarchy)
in
+ (* The additional information about the forward function *)
+ let fwd_info =
+ (* *)
+ let has_fuel = fwd_fuel <> [] in
+ let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in
+ let num_inputs_with_fuel_no_state =
+ (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
+ List.length fwd_fuel + num_inputs_no_fuel_no_state
+ in
+ let fwd_info : inputs_info =
+ {
+ has_fuel;
+ num_inputs_no_fuel_no_state;
+ num_inputs_with_fuel_no_state;
+ num_inputs_with_fuel_with_state =
+ (* We use the fact that [fwd_state_ty] has length 1 if there is a state,
+ and 0 otherwise *)
+ num_inputs_with_fuel_no_state + List.length fwd_state_ty;
+ }
+ in
+ let ignore_output =
+ if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then
+ ty_is_unit fwd_output
+ && List.exists
+ (fun (info : back_sg_info) -> not info.filter)
+ (RegionGroupId.Map.values back_sg)
+ else false
+ in
+ let info = { fwd_info; effect_info = fwd_effect_info; ignore_output } in
+ assert (fun_sig_info_is_wf info);
+ info
+ in
+
(* Generic parameters *)
let generics = translate_generic_params sg.generics in
@@ -1134,6 +1147,13 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
in
if effect_info.can_fail then mk_result_ty output else output
+let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info)
+ (inputs : ty list) (ty : ty) : ty =
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty
+ in
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output else output
+
(** Compute the arrow types for all the backward functions.
If a backward function has no inputs/outputs we filter it.
@@ -1151,7 +1171,9 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig)
None
else
let output = mk_simpl_tuple_ty outputs in
- let output = mk_output_ty_from_effect_info effect_info output in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
let ty = mk_arrows inputs output in
(* Substitute - TODO: normalize *)
let ty =
@@ -1166,6 +1188,25 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig)
Some ty)
(RegionGroupId.Map.values dsg.back_sg)
+(** In case we merge the fwd/back functions: compute the output type of
+ a function, from a decomposed signature. *)
+let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
+ assert !Config.return_back_funs;
+ (* Compute the arrow types for all the backward functions *)
+ let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
+ (* Group the forward output and the types of the backward functions *)
+ let effect_info = dsg.fwd_info.effect_info in
+ let output =
+ (* We might need to ignore the output of the forward function
+ (if it is unit for instance) *)
+ let tys =
+ if dsg.fwd_info.ignore_output then back_tys
+ else dsg.fwd_output :: back_tys
+ in
+ mk_simpl_tuple_ty tys
+ in
+ mk_output_ty_from_effect_info effect_info output
+
let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid : RegionGroupId.id option) : fun_sig =
let generics = dsg.generics in
@@ -1180,19 +1221,12 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
(gid, info.effect_info))
(RegionGroupId.Map.bindings dsg.back_sg))
in
- (* Two cases depending on whether we split the forward/backward functions
- or not *)
let mk_output_ty = mk_output_ty_from_effect_info in
-
let inputs, output =
+ (* Two cases depending on whether we split the forward/backward functions or not *)
if !Config.return_back_funs then (
assert (gid = None);
- (* Compute the arrow types for all the backward functions *)
- let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
- (* Group the forward output and the types of the backward functions *)
- let effect_info = dsg.fwd_info.effect_info in
- let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in
- let output = mk_output_ty effect_info output in
+ let output = compute_output_ty_from_decomposed dsg in
let inputs = dsg.fwd_inputs in
(inputs, output))
else
@@ -1785,7 +1819,11 @@ and translate_panic (ctx : bs_ctx) : texpression =
if !Config.return_back_funs then
let back_tys = compute_back_tys ctx.sg None in
let back_tys = List.filter_map (fun x -> x) back_tys in
- let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in
+ let tys =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty tys in
mk_output output
else mk_output ctx.sg.fwd_output
| Some bid ->
@@ -1798,6 +1836,9 @@ and translate_panic (ctx : bs_ctx) : texpression =
Remark: for now, we can't get there if we are inside a loop.
If inside a loop, we use {!translate_return_with_loop}.
+
+ Remark: in case we merge the forward/backward functions, we introduce
+ those in [translate_forward_end].
*)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
@@ -2648,6 +2689,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
If (true_e, false_e) )
in
let ty = true_e.ty in
+ log#ldebug
+ (lazy
+ ("true_e.ty: "
+ ^ pure_ty_to_string ctx true_e.ty
+ ^ "\n\nfalse_e.ty: "
+ ^ pure_ty_to_string ctx false_e.ty));
assert (ty = false_e.ty);
{ e; ty }
| ExpandInt (int_ty, branches, otherwise) ->
@@ -2941,37 +2988,63 @@ and translate_forward_end (ectx : C.eval_ctx)
in
let fwd_e = translate_one_end ctx None in
- (* Introduce the backward functions *)
+ (* Introduce the backward functions. *)
let back_el =
List.map
(fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
translate_one_end ctx (Some gid))
(RegionGroupId.Map.bindings ctx.sg.back_sg)
in
+
+ (* Compute whether the backward expressions should be evaluated straight
+ away or not (i.e., if we should bind them with monadic let-bindings
+ or not). We evaluate them straight away if they can fail and have no
+ inputs *)
+ let evaluate_backs =
+ List.map
+ (fun (sg : back_sg_info) ->
+ if !Config.simplify_merged_fwd_backs then
+ sg.inputs = [] && sg.effect_info.can_fail
+ else false)
+ (RegionGroupId.Map.values ctx.sg.back_sg)
+ in
+
(* Introduce variables for the backward functions.
We lookup the LLBC definition in an attempt to derive pretty names
for those functions. *)
let _, back_vars = fresh_back_vars_for_current_fun ctx in
(* Create the return expressions *)
- let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in
+ let vars =
+ let back_vars = List.filter_map (fun x -> x) back_vars in
+ if ctx.sg.fwd_info.ignore_output then back_vars
+ else fwd_var :: back_vars
+ in
let vars = List.map mk_texpression_from_var vars in
let ret = mk_simpl_tuple_texpression vars in
let state_var = List.map mk_texpression_from_var state_var in
let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
let ret = mk_result_return_texpression ret in
- (* Bind the expressions for the backward function and the expression
- for the computation of the forward output *)
+ (* Introduce all the let-bindings *)
+
+ (* Combine:
+ - the backward variables
+ - whether we should evaluate the expression for the backward function
+ (i.e., should we use a monadic let-binding or not - we do if the
+ backward functions don't have inputs and can fail)
+ - the expressions for the backward functions
+ *)
let back_vars_els =
List.filter_map
- (fun (v, el) -> match v with None -> None | Some v -> Some (v, el))
- (List.combine back_vars back_el)
+ (fun (v, (eval, el)) ->
+ match v with None -> None | Some v -> Some (v, eval, el))
+ (List.combine back_vars (List.combine evaluate_backs back_el))
in
let e =
List.fold_right
- (fun (var, back_e) e ->
- mk_let false (mk_typed_pattern_from_var var None) back_e e)
+ (fun (var, evaluate, back_e) e ->
+ mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
back_vars_els ret
in
(* Bind the expression for the forward output *)