diff options
author | Son Ho | 2023-12-21 19:23:29 +0100 |
---|---|---|
committer | Son Ho | 2023-12-21 19:23:29 +0100 |
commit | 2f681446b11739e650b1d6050b717da872be9022 (patch) | |
tree | 475ca390fb80d65735590e1be600239b597e1528 | |
parent | ccfcadc3686e69c1b8a8c826ec14f3c0e1dfbd7b (diff) |
Simplify the type of the merged fwd/back functions
-rw-r--r-- | compiler/Config.ml | 26 | ||||
-rw-r--r-- | compiler/Pure.ml | 6 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 7 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 1 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 159 |
5 files changed, 153 insertions, 46 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index b8af6c6d..2bb1ca34 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -370,6 +370,32 @@ let filter_useless_monadic_calls = ref true *) let filter_useless_functions = ref true +(** Simplify the forward/backward functions, in case we merge them + (i.e., the forward functions return the backward functions). + + The simplification occurs as follows: + - if a forward function returns the unit type and has non-trivial backward + functions, then we remove the returned output. + - if a backward function doesn't have inputs, we evaluate it inside the + forward function and don't wrap it in a result. + + Example: + {[ + // LLBC: + fn incr(x: &mut u32) { *x += 1 } + + // Translation without simplification: + let incr (x : u32) : result (unit * result u32) = ... + ^^^^ ^^^^^^ + | remove this result + remove the unit + + // Translation with simplification: + let incr (x : u32) : result u32 = ... + ]} + *) +let simplify_merged_fwd_backs = ref true + (** Use short names for the record fields. Some backends can't disambiguate records when their field names have collisions. diff --git a/compiler/Pure.ml b/compiler/Pure.ml index ddacf0c4..05cdbd70 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -908,6 +908,11 @@ type fun_sig_info = { fwd_info : inputs_info; (** Information about the inputs of the forward function *) effect_info : fun_effect_info; + ignore_output : bool; + (** In case we merge the forward/backward functions: should we ignore + the output (happens for forward functions if the output type is + [unit] and there are non-filtered backward functions)? + *) } [@@deriving show] @@ -939,6 +944,7 @@ type back_sg_info = { We derive those from the names of the inputs of the original LLBC function. *) effect_info : fun_effect_info; + filter : bool; (** Should we filter this backward function? *) } [@@deriving show] diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 63436e7d..16bf1c08 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1336,6 +1336,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : let fun_sig = def.signature in let fwd_info = fun_sig.fwd_info in let fwd_effect_info = fwd_info.effect_info in + let ignore_output = fwd_info.ignore_output in (* Generate the loop definition *) let loop_fwd_effect_info = fwd_effect_info in @@ -1358,7 +1359,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : } in - { fwd_info; effect_info = loop_fwd_effect_info } + { fwd_info; effect_info = loop_fwd_effect_info; ignore_output } in assert (fun_sig_info_is_wf loop_fwd_sig_info); @@ -2187,7 +2188,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : } = decl.signature in - let { fwd_info; effect_info } = fwd_info in + let { fwd_info; effect_info; ignore_output } = fwd_info in let { has_fuel; @@ -2212,7 +2213,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : } in - let fwd_info = { fwd_info; effect_info } in + let fwd_info = { fwd_info; effect_info; ignore_output } in assert (fun_sig_info_is_wf fwd_info); let signature = { diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index d4aaba16..78d0b120 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -448,6 +448,7 @@ let mk_simpl_tuple_ty (tys : ty list) : ty = let mk_bool_ty : ty = TLiteral TBool let mk_unit_ty : ty = TAdt (TTuple, empty_generic_args) +let ty_is_unit ty : bool = ty = mk_unit_ty let mk_unit_rvalue : texpression = let id = AdtCons { adt_id = TTuple; variant_id = None } in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index f37ea201..70a4e18d 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -979,30 +979,6 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed in (* Compute the backward output, without the effect information *) let fwd_output = translate_fwd_ty type_infos sg.output in - (* The additinoal information *) - let fwd_info = - (* *) - let has_fuel = fwd_fuel <> [] in - let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in - let num_inputs_with_fuel_no_state = - (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fwd_fuel + num_inputs_no_fuel_no_state - in - let fwd_info : inputs_info = - { - has_fuel; - num_inputs_no_fuel_no_state; - num_inputs_with_fuel_no_state; - num_inputs_with_fuel_with_state = - (* We use the fact that [fwd_state_ty] has length 1 if there is a state, - and 0 otherwise *) - num_inputs_with_fuel_no_state + List.length fwd_state_ty; - } - in - let info = { fwd_info; effect_info = fwd_effect_info } in - assert (fun_sig_info_is_wf info); - info - in (* Compute the type information for the backward function *) (* Small helper to translate types for backward functions *) @@ -1086,6 +1062,9 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed in let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in + let filter = + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + in let info = { inputs; @@ -1093,6 +1072,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed outputs; output_names; effect_info = back_effect_info; + filter; } in (gid, info) @@ -1102,6 +1082,39 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed (List.map compute_back_info_for_group regions_hierarchy) in + (* The additional information about the forward function *) + let fwd_info = + (* *) + let has_fuel = fwd_fuel <> [] in + let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in + let num_inputs_with_fuel_no_state = + (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) + List.length fwd_fuel + num_inputs_no_fuel_no_state + in + let fwd_info : inputs_info = + { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = + (* We use the fact that [fwd_state_ty] has length 1 if there is a state, + and 0 otherwise *) + num_inputs_with_fuel_no_state + List.length fwd_state_ty; + } + in + let ignore_output = + if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then + ty_is_unit fwd_output + && List.exists + (fun (info : back_sg_info) -> not info.filter) + (RegionGroupId.Map.values back_sg) + else false + in + let info = { fwd_info; effect_info = fwd_effect_info; ignore_output } in + assert (fun_sig_info_is_wf info); + info + in + (* Generic parameters *) let generics = translate_generic_params sg.generics in @@ -1134,6 +1147,13 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output +let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info) + (inputs : ty list) (ty : ty) : ty = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty + in + if effect_info.can_fail && inputs <> [] then mk_result_ty output else output + (** Compute the arrow types for all the backward functions. If a backward function has no inputs/outputs we filter it. @@ -1151,7 +1171,9 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) None else let output = mk_simpl_tuple_ty outputs in - let output = mk_output_ty_from_effect_info effect_info output in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in let ty = mk_arrows inputs output in (* Substitute - TODO: normalize *) let ty = @@ -1166,6 +1188,25 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) Some ty) (RegionGroupId.Map.values dsg.back_sg) +(** In case we merge the fwd/back functions: compute the output type of + a function, from a decomposed signature. *) +let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = + assert !Config.return_back_funs; + (* Compute the arrow types for all the backward functions *) + let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in + (* Group the forward output and the types of the backward functions *) + let effect_info = dsg.fwd_info.effect_info in + let output = + (* We might need to ignore the output of the forward function + (if it is unit for instance) *) + let tys = + if dsg.fwd_info.ignore_output then back_tys + else dsg.fwd_output :: back_tys + in + mk_simpl_tuple_ty tys + in + mk_output_ty_from_effect_info effect_info output + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1180,19 +1221,12 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid, info.effect_info)) (RegionGroupId.Map.bindings dsg.back_sg)) in - (* Two cases depending on whether we split the forward/backward functions - or not *) let mk_output_ty = mk_output_ty_from_effect_info in - let inputs, output = + (* Two cases depending on whether we split the forward/backward functions or not *) if !Config.return_back_funs then ( assert (gid = None); - (* Compute the arrow types for all the backward functions *) - let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in - (* Group the forward output and the types of the backward functions *) - let effect_info = dsg.fwd_info.effect_info in - let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in - let output = mk_output_ty effect_info output in + let output = compute_output_ty_from_decomposed dsg in let inputs = dsg.fwd_inputs in (inputs, output)) else @@ -1785,7 +1819,11 @@ and translate_panic (ctx : bs_ctx) : texpression = if !Config.return_back_funs then let back_tys = compute_back_tys ctx.sg None in let back_tys = List.filter_map (fun x -> x) back_tys in - let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in mk_output output else mk_output ctx.sg.fwd_output | Some bid -> @@ -1798,6 +1836,9 @@ and translate_panic (ctx : bs_ctx) : texpression = Remark: for now, we can't get there if we are inside a loop. If inside a loop, we use {!translate_return_with_loop}. + + Remark: in case we merge the forward/backward functions, we introduce + those in [translate_forward_end]. *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = @@ -2648,6 +2689,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) If (true_e, false_e) ) in let ty = true_e.ty in + log#ldebug + (lazy + ("true_e.ty: " + ^ pure_ty_to_string ctx true_e.ty + ^ "\n\nfalse_e.ty: " + ^ pure_ty_to_string ctx false_e.ty)); assert (ty = false_e.ty); { e; ty } | ExpandInt (int_ty, branches, otherwise) -> @@ -2941,37 +2988,63 @@ and translate_forward_end (ectx : C.eval_ctx) in let fwd_e = translate_one_end ctx None in - (* Introduce the backward functions *) + (* Introduce the backward functions. *) let back_el = List.map (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> translate_one_end ctx (Some gid)) (RegionGroupId.Map.bindings ctx.sg.back_sg) in + + (* Compute whether the backward expressions should be evaluated straight + away or not (i.e., if we should bind them with monadic let-bindings + or not). We evaluate them straight away if they can fail and have no + inputs *) + let evaluate_backs = + List.map + (fun (sg : back_sg_info) -> + if !Config.simplify_merged_fwd_backs then + sg.inputs = [] && sg.effect_info.can_fail + else false) + (RegionGroupId.Map.values ctx.sg.back_sg) + in + (* Introduce variables for the backward functions. We lookup the LLBC definition in an attempt to derive pretty names for those functions. *) let _, back_vars = fresh_back_vars_for_current_fun ctx in (* Create the return expressions *) - let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in + let vars = + let back_vars = List.filter_map (fun x -> x) back_vars in + if ctx.sg.fwd_info.ignore_output then back_vars + else fwd_var :: back_vars + in let vars = List.map mk_texpression_from_var vars in let ret = mk_simpl_tuple_texpression vars in let state_var = List.map mk_texpression_from_var state_var in let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in let ret = mk_result_return_texpression ret in - (* Bind the expressions for the backward function and the expression - for the computation of the forward output *) + (* Introduce all the let-bindings *) + + (* Combine: + - the backward variables + - whether we should evaluate the expression for the backward function + (i.e., should we use a monadic let-binding or not - we do if the + backward functions don't have inputs and can fail) + - the expressions for the backward functions + *) let back_vars_els = List.filter_map - (fun (v, el) -> match v with None -> None | Some v -> Some (v, el)) - (List.combine back_vars back_el) + (fun (v, (eval, el)) -> + match v with None -> None | Some v -> Some (v, eval, el)) + (List.combine back_vars (List.combine evaluate_backs back_el)) in let e = List.fold_right - (fun (var, back_e) e -> - mk_let false (mk_typed_pattern_from_var var None) back_e e) + (fun (var, evaluate, back_e) e -> + mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) back_vars_els ret in (* Bind the expression for the forward output *) |